From 80281df8d82e05ef53be2fc0b5e548d55361d3d7 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 29 Jun 2021 16:49:24 -0500 Subject: [PATCH] Replace instance.samples by instance.get/push_sample --- miplearn/components/component.py | 5 +++-- miplearn/instance/base.py | 8 +++++++- miplearn/instance/picklegz.py | 13 ++++++++++-- miplearn/solvers/learning.py | 2 +- tests/components/test_dynamic_lazy.py | 24 +++++++++------------- tests/components/test_dynamic_user_cuts.py | 2 +- tests/components/test_static_lazy.py | 4 ++-- tests/problems/test_tsp.py | 14 +++++++------ tests/solvers/test_learning_solver.py | 10 ++++----- 9 files changed, 48 insertions(+), 34 deletions(-) diff --git a/miplearn/components/component.py b/miplearn/components/component.py index c02628a..cf7e104 100644 --- a/miplearn/components/component.py +++ b/miplearn/components/component.py @@ -187,13 +187,14 @@ class Component: instances: List[Instance], n_jobs: int = 1, ) -> None: + # Part I: Pre-fit def _pre_sample_xy(instance: Instance) -> Dict: pre_instance: Dict = {} for (cidx, comp) in enumerate(components): pre_instance[cidx] = [] instance.load() - for sample in instance.samples: + for sample in instance.get_samples(): for (cidx, comp) in enumerate(components): pre_instance[cidx].append(comp.pre_sample_xy(instance, sample)) instance.free() @@ -219,7 +220,7 @@ class Component: x_instance[cidx] = {} y_instance[cidx] = {} instance.load() - for sample in instance.samples: + for sample in instance.get_samples(): for (cidx, comp) in enumerate(components): x = x_instance[cidx] y = y_instance[cidx] diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index 4b4ea86..c14df41 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -28,7 +28,7 @@ class Instance(ABC): """ def __init__(self) -> None: - self.samples: List[Sample] = [] + self._samples: List[Sample] = [] @abstractmethod def to_model(self) -> Any: @@ -189,3 +189,9 @@ class Instance(ABC): Save any pending changes made to the instance to the underlying data store. """ pass + + def get_samples(self) -> List[Sample]: + return self._samples + + def push_sample(self, sample: Sample) -> None: + self._samples.append(sample) diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index 3569299..9cb4e2e 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -10,6 +10,7 @@ from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict from overrides import overrides +from miplearn.features import Sample from miplearn.instance.base import Instance if TYPE_CHECKING: @@ -120,18 +121,26 @@ class PickleGzInstance(Instance): obj = read_pickle_gz(self.filename) assert isinstance(obj, Instance) self.instance = obj - self.samples = self.instance.samples @overrides def free(self) -> None: self.instance = None # type: ignore - self.samples = None # type: ignore gc.collect() @overrides def flush(self) -> None: write_pickle_gz(self.instance, self.filename) + @overrides + def get_samples(self) -> List[Sample]: + assert self.instance is not None + return self.instance.get_samples() + + @overrides + def push_sample(self, sample: Sample) -> None: + assert self.instance is not None + self.instance.push_sample(sample) + def write_pickle_gz(obj: Any, filename: str) -> None: os.makedirs(os.path.dirname(filename), exist_ok=True) diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index a6eb395..e7e07ee 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -150,7 +150,7 @@ class LearningSolver: # Initialize training sample # ------------------------------------------------------- sample = Sample() - instance.samples.append(sample) + instance.push_sample(sample) # Initialize stats # ------------------------------------------------------- diff --git a/tests/components/test_dynamic_lazy.py b/tests/components/test_dynamic_lazy.py index f33185c..bd43eaa 100644 --- a/tests/components/test_dynamic_lazy.py +++ b/tests/components/test_dynamic_lazy.py @@ -25,7 +25,7 @@ E = 0.1 @pytest.fixture def training_instances() -> List[Instance]: instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)] - instances[0].samples = [ + samples_0 = [ Sample( after_load=Features(instance=InstanceFeatures()), after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}), @@ -35,12 +35,9 @@ def training_instances() -> List[Instance]: after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}), ), ] - instances[0].samples[0].after_load.instance.to_list = Mock( # type: ignore - return_value=[5.0] - ) - instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore - return_value=[5.0] - ) + samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore + samples_0[1].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore + instances[0].get_samples = Mock(return_value=samples_0) # type: ignore instances[0].get_constraint_categories = Mock( # type: ignore return_value={ "c1": "type-a", @@ -57,15 +54,14 @@ def training_instances() -> List[Instance]: "c4": [3.0, 4.0], } ) - instances[1].samples = [ + samples_1 = [ Sample( after_load=Features(instance=InstanceFeatures()), after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}), ) ] - instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore - return_value=[8.0] - ) + samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore + instances[1].get_samples = Mock(return_value=samples_1) # type: ignore instances[1].get_constraint_categories = Mock( # type: ignore return_value={ "c1": None, @@ -97,7 +93,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None: } x_actual, y_actual = comp.sample_xy( training_instances[0], - training_instances[0].samples[0], + training_instances[0].get_samples()[0], ) assert_equals(x_actual, x_expected) assert_equals(y_actual, y_expected) @@ -184,12 +180,12 @@ def test_sample_predict_evaluate(training_instances: List[Instance]) -> None: ) pred = comp.sample_predict( training_instances[0], - training_instances[0].samples[0], + training_instances[0].get_samples()[0], ) assert pred == ["c1", "c4"] ev = comp.sample_evaluate( training_instances[0], - training_instances[0].samples[0], + training_instances[0].get_samples()[0], ) assert ev == { "type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1), diff --git a/tests/components/test_dynamic_user_cuts.py b/tests/components/test_dynamic_user_cuts.py index 9323e7e..c46a955 100644 --- a/tests/components/test_dynamic_user_cuts.py +++ b/tests/components/test_dynamic_user_cuts.py @@ -80,7 +80,7 @@ def test_usage( solver: LearningSolver, ) -> None: stats_before = solver.solve(stab_instance) - sample = stab_instance.samples[0] + sample = stab_instance.get_samples()[0] assert sample.after_mip is not None assert sample.after_mip.extra is not None assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0 diff --git a/tests/components/test_static_lazy.py b/tests/components/test_static_lazy.py index be839de..67c9c5d 100644 --- a/tests/components/test_static_lazy.py +++ b/tests/components/test_static_lazy.py @@ -70,7 +70,7 @@ def sample() -> Sample: @pytest.fixture def instance(sample: Sample) -> Instance: instance = Mock(spec=Instance) - instance.samples = [sample] + instance.get_samples = Mock(return_value=[sample]) # type: ignore instance.has_static_lazy_constraints = Mock(return_value=True) return instance @@ -111,7 +111,7 @@ def test_usage_with_solver(instance: Instance) -> None: ) stats: LearningSolveStats = {} - sample = instance.samples[0] + sample = instance.get_samples()[0] assert sample.after_load is not None assert sample.after_mip is not None assert sample.after_mip.extra is not None diff --git a/tests/problems/test_tsp.py b/tests/problems/test_tsp.py index 4018dc4..9b26ccd 100644 --- a/tests/problems/test_tsp.py +++ b/tests/problems/test_tsp.py @@ -39,9 +39,10 @@ def test_instance() -> None: instance = TravelingSalesmanInstance(n_cities, distances) solver = LearningSolver() solver.solve(instance) - assert len(instance.samples) == 1 - assert instance.samples[0].after_mip is not None - features = instance.samples[0].after_mip + assert len(instance.get_samples()) == 1 + sample = instance.get_samples()[0] + assert sample.after_mip is not None + features = sample.after_mip assert features is not None assert features.variables is not None assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0] @@ -66,9 +67,10 @@ def test_subtour() -> None: instance = TravelingSalesmanInstance(n_cities, distances) solver = LearningSolver() solver.solve(instance) - assert len(instance.samples) == 1 - assert instance.samples[0].after_mip is not None - features = instance.samples[0].after_mip + assert len(instance.get_samples()) == 1 + sample = instance.get_samples()[0] + assert sample.after_mip is not None + features = sample.after_mip assert features.extra is not None assert "lazy_enforced" in features.extra lazy_enforced = features.extra["lazy_enforced"] diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py index 8c51a17..19ad2b2 100644 --- a/tests/solvers/test_learning_solver.py +++ b/tests/solvers/test_learning_solver.py @@ -35,8 +35,8 @@ def test_learning_solver( ) solver.solve(instance) - assert len(instance.samples) > 0 - sample = instance.samples[0] + assert len(instance.get_samples()) > 0 + sample = instance.get_samples()[0] after_mip = sample.after_mip assert after_mip is not None @@ -90,7 +90,7 @@ def test_parallel_solve( results = solver.parallel_solve(instances, n_jobs=3) assert len(results) == 10 for instance in instances: - assert len(instance.samples) == 1 + assert len(instance.get_samples()) == 1 def test_solve_fit_from_disk( @@ -109,13 +109,13 @@ def test_solve_fit_from_disk( solver = LearningSolver(solver=internal_solver) solver.solve(instances[0]) instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename) - assert len(instance_loaded.samples) > 0 + assert len(instance_loaded.get_samples()) > 0 # Test: parallel_solve solver.parallel_solve(instances) for instance in instances: instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename) - assert len(instance_loaded.samples) > 0 + assert len(instance_loaded.get_samples()) > 0 # Delete temporary files for instance in instances: