Primal: Compute statistics

master
Alinson S. Xavier 5 years ago
parent b0b013dd0a
commit 203afc6993

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Union, Dict, Callable, List, Hashable, Optional from typing import Union, Dict, Callable, List, Hashable, Optional, Any, TYPE_CHECKING
import numpy as np import numpy as np
from tqdm.auto import tqdm from tqdm.auto import tqdm
@ -15,10 +15,13 @@ from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance from miplearn.instance import Instance
from miplearn.types import TrainingSample, VarIndex, Solution from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class PrimalSolutionComponent(Component): class PrimalSolutionComponent(Component):
""" """
@ -44,11 +47,31 @@ class PrimalSolutionComponent(Component):
self.thresholds: Dict[Hashable, Threshold] = {} self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold self.threshold_factory = threshold
self.classifier_factory = classifier self.classifier_factory = classifier
self.stats: Dict[str, float] = {}
self._n_free = 0
self._n_zero = 0
self._n_one = 0
def before_solve(self, solver, instance, model): def before_solve(self, solver, instance, model):
if len(self.thresholds) > 0: if len(self.thresholds) > 0:
logger.info("Predicting primal solution...") logger.info("Predicting primal solution...")
solution = self.predict(instance) solution = self.predict(instance)
# Collect prediction statistics
self._n_free = 0
self._n_zero = 0
self._n_one = 0
for (var, var_dict) in solution.items():
for (idx, value) in var_dict.items():
if value is None:
self._n_free += 1
else:
if value < 0.5:
self._n_zero += 1
else:
self._n_one += 1
# Provide solution to the solver
if self.mode == "heuristic": if self.mode == "heuristic":
solver.internal_solver.fix(solution) solver.internal_solver.fix(solution)
else: else:
@ -56,13 +79,15 @@ class PrimalSolutionComponent(Component):
def after_solve( def after_solve(
self, self,
solver, solver: "LearningSolver",
instance, instance: Instance,
model, model: Any,
stats, stats: LearningSolveStats,
training_data, training_data: TrainingSample,
): ) -> None:
pass stats["Primal: free"] = self._n_free
stats["Primal: zero"] = self._n_zero
stats["Primal: one"] = self._n_one
def x( def x(
self, self,

@ -62,6 +62,9 @@ LearningSolveStats = TypedDict(
"Upper bound": Optional[float], "Upper bound": Optional[float],
"Wallclock time": float, "Wallclock time": float,
"Warm start value": Optional[float], "Warm start value": Optional[float],
"Primal: free": int,
"Primal: zero": int,
"Primal: one": int,
}, },
total=False, total=False,
) )

@ -29,7 +29,7 @@ def test_benchmark():
benchmark = BenchmarkRunner(test_solvers) benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances) benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2) benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.results.values.shape == (12, 14) assert benchmark.results.values.shape == (12, 17)
benchmark.write_csv("/tmp/benchmark.csv") benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv")

Loading…
Cancel
Save