Convert LPSolveStats into dataclass

master
Alinson S. Xavier 5 years ago
parent 6afdf2ed55
commit 945f6a091c
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver):
opt_value = None opt_value = None
if not self.is_infeasible(): if not self.is_infeasible():
opt_value = self.model.objVal opt_value = self.model.objVal
return { return LPSolveStats(
"LP value": opt_value, lp_value=opt_value,
"LP log": log, lp_log=log,
} )
@overrides @overrides
def solve( def solve(

@ -4,6 +4,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from overrides import EnforceOverrides from overrides import EnforceOverrides
@ -11,7 +12,6 @@ from overrides import EnforceOverrides
from miplearn.features import Constraint, Variable from miplearn.features import Constraint, Variable
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.types import ( from miplearn.types import (
LPSolveStats,
IterationCallback, IterationCallback,
LazyCallback, LazyCallback,
MIPSolveStats, MIPSolveStats,
@ -24,6 +24,12 @@ from miplearn.types import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class LPSolveStats:
lp_log: Optional[str] = None
lp_value: Optional[float] = None
class InternalSolver(ABC, EnforceOverrides): class InternalSolver(ABC, EnforceOverrides):
""" """
Abstract class representing the MIP solver used internally by LearningSolver. Abstract class representing the MIP solver used internally by LearningSolver.

@ -178,10 +178,10 @@ class LearningSolver:
logger.info("Solving root LP relaxation...") logger.info("Solving root LP relaxation...")
lp_stats = self.internal_solver.solve_lp(tee=tee) lp_stats = self.internal_solver.solve_lp(tee=tee)
stats.update(cast(LearningSolveStats, lp_stats)) stats.update(cast(LearningSolveStats, lp_stats.__dict__))
training_sample.lp_solution = self.internal_solver.get_solution() training_sample.lp_solution = self.internal_solver.get_solution()
training_sample.lp_value = lp_stats["LP value"] training_sample.lp_value = lp_stats.lp_value
training_sample.lp_log = lp_stats["LP log"] training_sample.lp_log = lp_stats.lp_log
logger.debug("Running after_solve_lp callbacks...") logger.debug("Running after_solve_lp callbacks...")
for component in self.components.values(): for component in self.components.values():
@ -240,8 +240,6 @@ class LearningSolver:
lazy_cb=lazy_cb, lazy_cb=lazy_cb,
) )
stats.update(cast(LearningSolveStats, mip_stats)) stats.update(cast(LearningSolveStats, mip_stats))
if training_sample.lp_value is not None:
stats["LP value"] = training_sample.lp_value
stats["Solver"] = "default" stats["Solver"] = "default"
stats["Gap"] = self._compute_gap( stats["Gap"] = self._compute_gap(
ub=stats["Upper bound"], ub=stats["Upper bound"],

@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver):
if not self.is_infeasible(): if not self.is_infeasible():
opt_value = results["Problem"][0]["Lower bound"] opt_value = results["Problem"][0]["Lower bound"]
self._has_lp_solution = True self._has_lp_solution = True
return { return LPSolveStats(
"LP value": opt_value, lp_value=opt_value,
"LP log": streams[0].getvalue(), lp_log=streams[0].getvalue(),
} )
def _restore_integrality(self) -> None: def _restore_integrality(self) -> None:
for var in self._bin_vars: for var in self._bin_vars:

@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
# Solve linear programming relaxation # Solve linear programming relaxation
lp_stats = solver.solve_lp() lp_stats = solver.solve_lp()
assert not solver.is_infeasible() assert not solver.is_infeasible()
assert lp_stats["LP value"] is not None assert lp_stats.lp_value is not None
assert_equals(round(lp_stats["LP value"], 3), 1287.923) assert lp_stats.lp_log is not None
assert len(lp_stats["LP log"]) > 100 assert_equals(round(lp_stats.lp_value, 3), 1287.923)
assert len(lp_stats.lp_log) > 100
# Fetch variables (after-load) # Fetch variables (after-load)
assert_equals( assert_equals(
@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None:
assert mip_stats["Lower bound"] is None assert mip_stats["Lower bound"] is None
lp_stats = solver.solve_lp() lp_stats = solver.solve_lp()
assert solver.get_solution() is None assert solver.get_solution() is None
assert lp_stats["LP value"] is None assert lp_stats.lp_value is None
def run_iteration_cb_tests(solver: InternalSolver) -> None: def run_iteration_cb_tests(solver: InternalSolver) -> None:

@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None]
VariableName = str VariableName = str
Solution = Dict[VariableName, Optional[float]] Solution = Dict[VariableName, Optional[float]]
LPSolveStats = TypedDict(
"LPSolveStats",
{
"LP log": str,
"LP value": Optional[float],
},
)
MIPSolveStats = TypedDict( MIPSolveStats = TypedDict(
"MIPSolveStats", "MIPSolveStats",
{ {

@ -34,7 +34,6 @@ def test_benchmark() -> None:
n_jobs=n_jobs, n_jobs=n_jobs,
n_trials=2, n_trials=2,
) )
assert benchmark.results.values.shape == (12, 20)
benchmark.write_csv("/tmp/benchmark.csv") benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv")
assert benchmark.results.values.shape == (12, 20)

Loading…
Cancel
Save