Primal: Compute statistics

master
Alinson S. Xavier 5 years ago
parent b0b013dd0a
commit 203afc6993

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Union, Dict, Callable, List, Hashable, Optional
from typing import Union, Dict, Callable, List, Hashable, Optional, Any, TYPE_CHECKING
import numpy as np
from tqdm.auto import tqdm
@ -15,10 +15,13 @@ from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import TrainingSample, VarIndex, Solution
from miplearn.types import TrainingSample, VarIndex, Solution, LearningSolveStats
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class PrimalSolutionComponent(Component):
"""
@ -44,11 +47,31 @@ class PrimalSolutionComponent(Component):
self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold
self.classifier_factory = classifier
self.stats: Dict[str, float] = {}
self._n_free = 0
self._n_zero = 0
self._n_one = 0
def before_solve(self, solver, instance, model):
if len(self.thresholds) > 0:
logger.info("Predicting primal solution...")
solution = self.predict(instance)
# Collect prediction statistics
self._n_free = 0
self._n_zero = 0
self._n_one = 0
for (var, var_dict) in solution.items():
for (idx, value) in var_dict.items():
if value is None:
self._n_free += 1
else:
if value < 0.5:
self._n_zero += 1
else:
self._n_one += 1
# Provide solution to the solver
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
@ -56,13 +79,15 @@ class PrimalSolutionComponent(Component):
def after_solve(
self,
solver,
instance,
model,
stats,
training_data,
):
pass
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
training_data: TrainingSample,
) -> None:
stats["Primal: free"] = self._n_free
stats["Primal: zero"] = self._n_zero
stats["Primal: one"] = self._n_one
def x(
self,

@ -62,6 +62,9 @@ LearningSolveStats = TypedDict(
"Upper bound": Optional[float],
"Wallclock time": float,
"Warm start value": Optional[float],
"Primal: free": int,
"Primal: zero": int,
"Primal: one": int,
},
total=False,
)

@ -29,7 +29,7 @@ def test_benchmark():
benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.results.values.shape == (12, 14)
assert benchmark.results.values.shape == (12, 17)
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")

Loading…
Cancel
Save