diff --git a/miplearn/components/objective.py b/miplearn/components/objective.py index 2fc5afd..33bc062 100644 --- a/miplearn/components/objective.py +++ b/miplearn/components/objective.py @@ -95,13 +95,12 @@ class ObjectiveValueComponent(Component): # Labels y: Dict[Hashable, List[List[float]]] = {} - if sample.after_mip is not None: - mip_stats = sample.after_mip.mip_solve - assert mip_stats is not None - if mip_stats.mip_lower_bound is not None: - y["Lower bound"] = [[mip_stats.mip_lower_bound]] - if mip_stats.mip_upper_bound is not None: - y["Upper bound"] = [[mip_stats.mip_upper_bound]] + mip_lower_bound = sample.get("mip_lower_bound") + mip_upper_bound = sample.get("mip_upper_bound") + if mip_lower_bound is not None: + y["Lower bound"] = [[mip_lower_bound]] + if mip_upper_bound is not None: + y["Upper bound"] = [[mip_upper_bound]] return x, y @@ -111,9 +110,6 @@ class ObjectiveValueComponent(Component): instance: Instance, sample: Sample, ) -> Dict[Hashable, Dict[str, float]]: - assert sample.after_mip is not None - assert sample.after_mip.mip_solve is not None - def compare(y_pred: float, y_actual: float) -> Dict[str, float]: err = np.round(abs(y_pred - y_actual), 8) return { @@ -125,8 +121,8 @@ class ObjectiveValueComponent(Component): result: Dict[Hashable, Dict[str, float]] = {} pred = self.sample_predict(sample) - actual_ub = sample.after_mip.mip_solve.mip_upper_bound - actual_lb = sample.after_mip.mip_solve.mip_lower_bound + actual_ub = sample.get("mip_upper_bound") + actual_lb = sample.get("mip_lower_bound") if actual_ub is not None: result["Upper bound"] = compare(pred["Upper bound"], actual_ub) if actual_lb is not None: diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index c37701d..9274745 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -155,6 +155,7 @@ class PrimalSolutionComponent(Component): assert sample.after_load.variables is not None assert sample.after_load.variables.names is not None assert sample.after_load.variables.categories is not None + mip_var_values = sample.get("mip_var_values") for (i, var_name) in enumerate(sample.after_load.variables.names): # Initialize categories @@ -174,10 +175,8 @@ class PrimalSolutionComponent(Component): x[category].append(features) # Labels - if sample.after_mip is not None: - assert sample.after_mip.variables is not None - assert sample.after_mip.variables.values is not None - opt_value = sample.after_mip.variables.values[i] + if mip_var_values is not None: + opt_value = mip_var_values[i] assert opt_value is not None assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( f"Variable {var_name} has non-binary value {opt_value} in the " @@ -194,14 +193,13 @@ class PrimalSolutionComponent(Component): _: Optional[Instance], sample: Sample, ) -> Dict[Hashable, Dict[str, float]]: - assert sample.after_mip is not None - assert sample.after_mip.variables is not None - assert sample.after_mip.variables.values is not None - assert sample.after_mip.variables.names is not None + mip_var_values = sample.get("mip_var_values") + var_names = sample.get("var_names") + assert mip_var_values is not None + assert var_names is not None solution_actual = { - var_name: sample.after_mip.variables.values[i] - for (i, var_name) in enumerate(sample.after_mip.variables.names) + var_name: mip_var_values[i] for (i, var_name) in enumerate(var_names) } solution_pred = self.sample_predict(sample) vars_all, vars_one, vars_zero = set(), set(), set() diff --git a/miplearn/features.py b/miplearn/features.py index b2dca96..d913a00 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -148,7 +148,6 @@ class Sample: self, after_load: Optional[Features] = None, after_lp: Optional[Features] = None, - after_mip: Optional[Features] = None, data: Optional[Dict[str, Any]] = None, ) -> None: if data is None: @@ -156,7 +155,6 @@ class Sample: self._data: Dict[str, Any] = data self.after_load = after_load self.after_lp = after_lp - self.after_mip = after_mip def get(self, key: str) -> Optional[Any]: if key in self._data: diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 9917566..cd5fb4c 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -210,6 +210,7 @@ class LearningSolver: # ------------------------------------------------------- logger.info("Extracting features (after-lp)...") initial_time = time.time() + self.extractor.extract_after_lp_features(self.internal_solver, sample) features = self.extractor.extract( instance, self.internal_solver, @@ -219,6 +220,8 @@ class LearningSolver: "Features (after-lp) extracted in %.2f seconds" % (time.time() - initial_time) ) + for (k, v) in lp_stats.__dict__.items(): + sample.put(k, v) features.lp_solve = lp_stats sample.after_lp = features @@ -282,17 +285,13 @@ class LearningSolver: # ------------------------------------------------------- logger.info("Extracting features (after-mip)...") initial_time = time.time() - features = self.extractor.extract( - instance, - self.internal_solver, - with_static=False, - ) + self.extractor.extract_after_mip_features(self.internal_solver, sample) + for (k, v) in mip_stats.__dict__.items(): + sample.put(k, v) logger.info( "Features (after-mip) extracted in %.2f seconds" % (time.time() - initial_time) ) - features.mip_solve = mip_stats - sample.after_mip = features # After-solve callbacks # ------------------------------------------------------- diff --git a/tests/components/test_objective.py b/tests/components/test_objective.py index bba86a7..6bb35ad 100644 --- a/tests/components/test_objective.py +++ b/tests/components/test_objective.py @@ -25,12 +25,10 @@ def sample() -> Sample: after_lp=Features( lp_solve=LPSolveStats(), ), - after_mip=Features( - mip_solve=MIPSolveStats( - mip_lower_bound=1.0, - mip_upper_bound=2.0, - ) - ), + data={ + "mip_lower_bound": 1.0, + "mip_upper_bound": 2.0, + }, ) sample.after_load.instance.to_list = Mock(return_value=[1.0, 2.0]) # type: ignore sample.after_lp.lp_solve.to_list = Mock(return_value=[3.0]) # type: ignore diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 1f83bc7..acd4ef6 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -36,12 +36,10 @@ def sample() -> Sample: after_lp=Features( variables=VariableFeatures(), ), - after_mip=Features( - variables=VariableFeatures( - names=["x[0]", "x[1]", "x[2]", "x[3]"], - values=[0.0, 1.0, 1.0, 0.0], - ) - ), + data={ + "var_names": ["x[0]", "x[1]", "x[2]", "x[3]"], + "mip_var_values": [0.0, 1.0, 1.0, 0.0], + }, ) sample.after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore sample.after_load.variables.to_list = Mock( # type:ignore diff --git a/tests/problems/test_tsp.py b/tests/problems/test_tsp.py index cf4f98b..16b9628 100644 --- a/tests/problems/test_tsp.py +++ b/tests/problems/test_tsp.py @@ -41,14 +41,9 @@ def test_instance() -> None: solver.solve(instance) assert len(instance.get_samples()) == 1 sample = instance.get_samples()[0] - assert sample.after_mip is not None - features = sample.after_mip - assert features is not None - assert features.variables is not None - assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0] - assert features.mip_solve is not None - assert features.mip_solve.mip_lower_bound == 4.0 - assert features.mip_solve.mip_upper_bound == 4.0 + assert sample.get("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0] + assert sample.get("mip_lower_bound") == 4.0 + assert sample.get("mip_upper_bound") == 4.0 def test_subtour() -> None: @@ -73,10 +68,7 @@ def test_subtour() -> None: lazy_enforced = sample.get("lazy_enforced") assert lazy_enforced is not None assert len(lazy_enforced) > 0 - assert sample.after_mip is not None - features = sample.after_mip - assert features.variables is not None - assert features.variables.values == [ + assert sample.get("mip_var_values") == [ 1.0, 0.0, 0.0, diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py index 19ad2b2..06b94db 100644 --- a/tests/solvers/test_learning_solver.py +++ b/tests/solvers/test_learning_solver.py @@ -38,25 +38,18 @@ def test_learning_solver( assert len(instance.get_samples()) > 0 sample = instance.get_samples()[0] - after_mip = sample.after_mip - assert after_mip is not None - assert after_mip.variables is not None - assert after_mip.variables.values == [1.0, 0.0, 1.0, 1.0, 61.0] - assert after_mip.mip_solve is not None - assert after_mip.mip_solve.mip_lower_bound == 1183.0 - assert after_mip.mip_solve.mip_upper_bound == 1183.0 - assert after_mip.mip_solve.mip_log is not None - assert len(after_mip.mip_solve.mip_log) > 100 - - after_lp = sample.after_lp - assert after_lp is not None - assert after_lp.variables is not None - assert_equals(after_lp.variables.values, [1.0, 0.923077, 1.0, 0.0, 67.0]) - assert after_lp.lp_solve is not None - assert after_lp.lp_solve.lp_value is not None - assert round(after_lp.lp_solve.lp_value, 3) == 1287.923 - assert after_lp.lp_solve.lp_log is not None - assert len(after_lp.lp_solve.lp_log) > 100 + assert sample.get("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 61.0] + assert sample.get("mip_lower_bound") == 1183.0 + assert sample.get("mip_upper_bound") == 1183.0 + mip_log = sample.get("mip_log") + assert mip_log is not None + assert len(mip_log) > 100 + + assert_equals(sample.get("lp_var_values"), [1.0, 0.923077, 1.0, 0.0, 67.0]) + assert_equals(sample.get("lp_value"), 1287.923077) + lp_log = sample.get("lp_log") + assert lp_log is not None + assert len(lp_log) > 100 solver.fit([instance], n_jobs=4) solver.solve(instance)