diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 80ab3de..78abb0b 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -13,6 +13,7 @@ from typing import ( Any, TYPE_CHECKING, Tuple, + cast, ) import numpy as np @@ -63,29 +64,30 @@ class PrimalSolutionComponent(Component): self._n_one = 0 def before_solve_mip(self, solver, instance, model): - if len(self.thresholds) > 0: - logger.info("Predicting primal solution...") - solution = self.predict(instance) - - # Collect prediction statistics - self._n_free = 0 - self._n_zero = 0 - self._n_one = 0 - for (var, var_dict) in solution.items(): - for (idx, value) in var_dict.items(): - if value is None: - self._n_free += 1 - else: - if value < 0.5: - self._n_zero += 1 - else: - self._n_one += 1 - - # Provide solution to the solver - if self.mode == "heuristic": - solver.internal_solver.fix(solution) - else: - solver.internal_solver.set_warm_start(solution) + pass + # if len(self.thresholds) > 0: + # logger.info("Predicting primal solution...") + # solution = self.predict(instance) + # + # # Collect prediction statistics + # self._n_free = 0 + # self._n_zero = 0 + # self._n_one = 0 + # for (var, var_dict) in solution.items(): + # for (idx, value) in var_dict.items(): + # if value is None: + # self._n_free += 1 + # else: + # if value < 0.5: + # self._n_zero += 1 + # else: + # self._n_one += 1 + # + # # Provide solution to the solver + # if self.mode == "heuristic": + # solver.internal_solver.fix(solution) + # else: + # solver.internal_solver.set_warm_start(solution) def after_solve_mip( self, @@ -214,43 +216,56 @@ class PrimalSolutionComponent(Component): if "Solution" not in sample: return x, y assert sample["Solution"] is not None - for (var, var_dict) in sample["Solution"].items(): - for (idx, opt_value) in var_dict.items(): - assert opt_value is not None - assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( - f"Variable {var} has non-binary value {opt_value} in the optimal " - f"solution. Predicting values of non-binary variables is not " - f"currently supported. Please set its category to None." - ) - category = instance.get_variable_category(var, idx) - if category is None: - continue - if category not in x.keys(): - x[category] = [] - y[category] = [] - features: Any = instance.get_variable_features(var, idx) - assert isinstance(features, list) - if "LP solution" in sample and sample["LP solution"] is not None: - lp_value = sample["LP solution"][var][idx] - if lp_value is not None: - features += [sample["LP solution"][var][idx]] - x[category] += [features] - y[category] += [[opt_value < 0.5, opt_value >= 0.5]] - return x, y + return cast( + Tuple[Dict, Dict], + PrimalSolutionComponent._extract( + instance, + sample, + sample["Solution"], + extract_y=True, + ), + ) @staticmethod def x_sample( instance: Any, sample: TrainingSample, ) -> Dict: + return cast( + Dict, + PrimalSolutionComponent._extract( + instance, + sample, + instance.model_features["Variables"], + extract_y=False, + ), + ) + + @staticmethod + def _extract( + instance: Any, + sample: TrainingSample, + variables: Dict, + extract_y: bool, + ) -> Union[Dict, Tuple[Dict, Dict]]: x: Dict = {} - for (var, var_dict) in instance.model_features["Variables"].items(): - for idx in var_dict.keys(): + y: Dict = {} + for (var, var_dict) in variables.items(): + for (idx, opt_value) in var_dict.items(): + if extract_y: + assert opt_value is not None + assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( + f"Variable {var} has non-binary value {opt_value} in the " + "optimal solution. Predicting values of non-binary " + "variables is not currently supported. Please set its " + "category to None." + ) category = instance.get_variable_category(var, idx) if category is None: continue if category not in x.keys(): x[category] = [] + y[category] = [] features: Any = instance.get_variable_features(var, idx) assert isinstance(features, list) if "LP solution" in sample and sample["LP solution"] is not None: @@ -258,6 +273,9 @@ class PrimalSolutionComponent(Component): if lp_value is not None: features += [sample["LP solution"][var][idx]] x[category] += [features] - for category in x.keys(): - x[category] = np.array(x[category]) - return x + if extract_y: + y[category] += [[opt_value < 0.5, opt_value >= 0.5]] + if extract_y: + return x, y + else: + return x