Convert LPSolveStats into dataclass

master
Alinson S. Xavier 5 years ago
parent 6afdf2ed55
commit 945f6a091c
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver):
opt_value = None
if not self.is_infeasible():
opt_value = self.model.objVal
return {
"LP value": opt_value,
"LP log": log,
}
return LPSolveStats(
lp_value=opt_value,
lp_log=log,
)
@overrides
def solve(

@ -4,6 +4,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from overrides import EnforceOverrides
@ -11,7 +12,6 @@ from overrides import EnforceOverrides
from miplearn.features import Constraint, Variable
from miplearn.instance.base import Instance
from miplearn.types import (
LPSolveStats,
IterationCallback,
LazyCallback,
MIPSolveStats,
@ -24,6 +24,12 @@ from miplearn.types import (
logger = logging.getLogger(__name__)
@dataclass
class LPSolveStats:
lp_log: Optional[str] = None
lp_value: Optional[float] = None
class InternalSolver(ABC, EnforceOverrides):
"""
Abstract class representing the MIP solver used internally by LearningSolver.

@ -178,10 +178,10 @@ class LearningSolver:
logger.info("Solving root LP relaxation...")
lp_stats = self.internal_solver.solve_lp(tee=tee)
stats.update(cast(LearningSolveStats, lp_stats))
stats.update(cast(LearningSolveStats, lp_stats.__dict__))
training_sample.lp_solution = self.internal_solver.get_solution()
training_sample.lp_value = lp_stats["LP value"]
training_sample.lp_log = lp_stats["LP log"]
training_sample.lp_value = lp_stats.lp_value
training_sample.lp_log = lp_stats.lp_log
logger.debug("Running after_solve_lp callbacks...")
for component in self.components.values():
@ -240,8 +240,6 @@ class LearningSolver:
lazy_cb=lazy_cb,
)
stats.update(cast(LearningSolveStats, mip_stats))
if training_sample.lp_value is not None:
stats["LP value"] = training_sample.lp_value
stats["Solver"] = "default"
stats["Gap"] = self._compute_gap(
ub=stats["Upper bound"],

@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver):
if not self.is_infeasible():
opt_value = results["Problem"][0]["Lower bound"]
self._has_lp_solution = True
return {
"LP value": opt_value,
"LP log": streams[0].getvalue(),
}
return LPSolveStats(
lp_value=opt_value,
lp_log=streams[0].getvalue(),
)
def _restore_integrality(self) -> None:
for var in self._bin_vars:

@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
# Solve linear programming relaxation
lp_stats = solver.solve_lp()
assert not solver.is_infeasible()
assert lp_stats["LP value"] is not None
assert_equals(round(lp_stats["LP value"], 3), 1287.923)
assert len(lp_stats["LP log"]) > 100
assert lp_stats.lp_value is not None
assert lp_stats.lp_log is not None
assert_equals(round(lp_stats.lp_value, 3), 1287.923)
assert len(lp_stats.lp_log) > 100
# Fetch variables (after-load)
assert_equals(
@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None:
assert mip_stats["Lower bound"] is None
lp_stats = solver.solve_lp()
assert solver.get_solution() is None
assert lp_stats["LP value"] is None
assert lp_stats.lp_value is None
def run_iteration_cb_tests(solver: InternalSolver) -> None:

@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None]
VariableName = str
Solution = Dict[VariableName, Optional[float]]
LPSolveStats = TypedDict(
"LPSolveStats",
{
"LP log": str,
"LP value": Optional[float],
},
)
MIPSolveStats = TypedDict(
"MIPSolveStats",
{

@ -34,7 +34,6 @@ def test_benchmark() -> None:
n_jobs=n_jobs,
n_trials=2,
)
assert benchmark.results.values.shape == (12, 20)
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")
assert benchmark.results.values.shape == (12, 20)

Loading…
Cancel
Save