diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index db7ea55..55db4d6 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging -from typing import Union, Dict, Callable, List, Hashable, Optional +from typing import Union, Dict, Callable, List, Hashable, Optional, Any, TYPE_CHECKING import numpy as np from tqdm.auto import tqdm @@ -15,10 +15,13 @@ from miplearn.components import classifier_evaluation_dict from miplearn.components.component import Component from miplearn.extractors import InstanceIterator from miplearn.instance import Instance -from miplearn.types import TrainingSample, VarIndex, Solution +from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from miplearn.solvers.learning import LearningSolver + class PrimalSolutionComponent(Component): """ @@ -44,11 +47,31 @@ class PrimalSolutionComponent(Component): self.thresholds: Dict[Hashable, Threshold] = {} self.threshold_factory = threshold self.classifier_factory = classifier + self.stats: Dict[str, float] = {} + self._n_free = 0 + self._n_zero = 0 + self._n_one = 0 def before_solve(self, solver, instance, model): if len(self.thresholds) > 0: logger.info("Predicting primal solution...") solution = self.predict(instance) + + # Collect prediction statistics + self._n_free = 0 + self._n_zero = 0 + self._n_one = 0 + for (var, var_dict) in solution.items(): + for (idx, value) in var_dict.items(): + if value is None: + self._n_free += 1 + else: + if value < 0.5: + self._n_zero += 1 + else: + self._n_one += 1 + + # Provide solution to the solver if self.mode == "heuristic": solver.internal_solver.fix(solution) else: @@ -56,13 +79,15 @@ class PrimalSolutionComponent(Component): def after_solve( self, - solver, - instance, - model, - stats, - training_data, - ): - pass + solver: "LearningSolver", + instance: Instance, + model: Any, + stats: LearningSolveStats, + training_data: TrainingSample, + ) -> None: + stats["Primal: free"] = self._n_free + stats["Primal: zero"] = self._n_zero + stats["Primal: one"] = self._n_one def x( self, diff --git a/miplearn/types.py b/miplearn/types.py index e7ec102..7978a73 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -62,6 +62,9 @@ LearningSolveStats = TypedDict( "Upper bound": Optional[float], "Wallclock time": float, "Warm start value": Optional[float], + "Primal: free": int, + "Primal: zero": int, + "Primal: one": int, }, total=False, ) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index e2950c8..7be39ec 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -29,7 +29,7 @@ def test_benchmark(): benchmark = BenchmarkRunner(test_solvers) benchmark.fit(train_instances) benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2) - assert benchmark.results.values.shape == (12, 14) + assert benchmark.results.values.shape == (12, 17) benchmark.write_csv("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv")