Primal: Compute statistics

This commit is contained in:
2021-01-25 16:02:40 -06:00
parent b0b013dd0a
commit 203afc6993
3 changed files with 38 additions and 10 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Union, Dict, Callable, List, Hashable, Optional from typing import Union, Dict, Callable, List, Hashable, Optional, Any, TYPE_CHECKING
import numpy as np import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
@@ -15,10 +15,13 @@ from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance from miplearn.instance import Instance
from miplearn.types import TrainingSample, VarIndex, Solution from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class PrimalSolutionComponent(Component): class PrimalSolutionComponent(Component):
""" """
@@ -44,11 +47,31 @@ class PrimalSolutionComponent(Component):
self.thresholds: Dict[Hashable, Threshold] = {} self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold self.threshold_factory = threshold
self.classifier_factory = classifier self.classifier_factory = classifier
self.stats: Dict[str, float] = {}
self._n_free = 0
self._n_zero = 0
self._n_one = 0
def before_solve(self, solver, instance, model): def before_solve(self, solver, instance, model):
if len(self.thresholds) > 0: if len(self.thresholds) > 0:
logger.info("Predicting primal solution...") logger.info("Predicting primal solution...")
solution = self.predict(instance) solution = self.predict(instance)
# Collect prediction statistics
self._n_free = 0
self._n_zero = 0
self._n_one = 0
for (var, var_dict) in solution.items():
for (idx, value) in var_dict.items():
if value is None:
self._n_free += 1
else:
if value < 0.5:
self._n_zero += 1
else:
self._n_one += 1
# Provide solution to the solver
if self.mode == "heuristic": if self.mode == "heuristic":
solver.internal_solver.fix(solution) solver.internal_solver.fix(solution)
else: else:
@@ -56,13 +79,15 @@ class PrimalSolutionComponent(Component):
def after_solve( def after_solve(
self, self,
solver, solver: "LearningSolver",
instance, instance: Instance,
model, model: Any,
stats, stats: LearningSolveStats,
training_data, training_data: TrainingSample,
): ) -> None:
pass stats["Primal: free"] = self._n_free
stats["Primal: zero"] = self._n_zero
stats["Primal: one"] = self._n_one
def x( def x(
self, self,

View File

@@ -62,6 +62,9 @@ LearningSolveStats = TypedDict(
"Upper bound": Optional[float], "Upper bound": Optional[float],
"Wallclock time": float, "Wallclock time": float,
"Warm start value": Optional[float], "Warm start value": Optional[float],
"Primal: free": int,
"Primal: zero": int,
"Primal: one": int,
}, },
total=False, total=False,
) )

View File

@@ -29,7 +29,7 @@ def test_benchmark():
benchmark = BenchmarkRunner(test_solvers) benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances) benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2) benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.results.values.shape == (12, 14) assert benchmark.results.values.shape == (12, 17)
benchmark.write_csv("/tmp/benchmark.csv") benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv")