From 60c7222fbe075d87c68628d480f7a834b46421b2 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Thu, 1 Feb 2024 10:18:24 -0600 Subject: [PATCH] Cuts: Call set_cuts instead of setting cuts_aot_ directly --- miplearn/components/cuts/mem.py | 5 +++-- miplearn/solvers/abstract.py | 3 +++ tests/components/cuts/test_mem.py | 8 +++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/miplearn/components/cuts/mem.py b/miplearn/components/cuts/mem.py index 3aa93e2..1c8d3b1 100644 --- a/miplearn/components/cuts/mem.py +++ b/miplearn/components/cuts/mem.py @@ -110,5 +110,6 @@ class MemorizingCutsComponent(_BaseMemorizingConstrComponent): if model.cuts_enforce is None: return assert self.constrs_ is not None - model.cuts_aot_ = self.predict("Predicting cutting planes...", test_h5) - stats["Cuts: AOT"] = len(model.cuts_aot_) + cuts = self.predict("Predicting cutting planes...", test_h5) + model.set_cuts(cuts) + stats["Cuts: AOT"] = len(cuts) diff --git a/miplearn/solvers/abstract.py b/miplearn/solvers/abstract.py index 6f750de..6341c85 100644 --- a/miplearn/solvers/abstract.py +++ b/miplearn/solvers/abstract.py @@ -82,3 +82,6 @@ class AbstractModel(ABC): @abstractmethod def write(self, filename: str) -> None: pass + + def set_cuts(self, cuts: List) -> None: + self.cuts_aot_ = cuts diff --git a/tests/components/cuts/test_mem.py b/tests/components/cuts/test_mem.py index 11797ee..7bebfb0 100644 --- a/tests/components/cuts/test_mem.py +++ b/tests/components/cuts/test_mem.py @@ -50,9 +50,11 @@ def test_mem_component_gp( (x_test,) = clf.predict.call_args.args assert x_test.shape == (1, 50) - # Should set cuts_aot_ - assert model.cuts_aot_ is not None - assert len(model.cuts_aot_) == 256 + # Should call set_cuts + model.set_cuts.assert_called() + (cuts_aot_,) = model.set_cuts.call_args.args + assert cuts_aot_ is not None + assert len(cuts_aot_) == 256 def test_usage_stab(