Convert LPSolveStats into dataclass

Date: 2021-04-11 08:41:50 -05:00
parent 6afdf2ed55
commit 945f6a091c
7 changed files with 24 additions and 28 deletions
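
In short: the LPSolveStats value returned by solve_lp moves from a TypedDict indexed by the string keys "LP log" / "LP value" to a dataclass with lp_log / lp_value fields. An illustrative sketch of the call-site change (not part of the commit; the class body is copied from the diff below, the example values from its tests):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LPSolveStats:
        lp_log: Optional[str] = None
        lp_value: Optional[float] = None

    # Before: solvers returned a dict and callers indexed it by string key.
    old_style = {"LP value": 1287.923, "LP log": "..."}
    value = old_style["LP value"]

    # After: solvers return LPSolveStats and callers read attributes.
    new_style = LPSolveStats(lp_value=1287.923, lp_log="...")
    value = new_style.lp_value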


@@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver):
         opt_value = None
         if not self.is_infeasible():
             opt_value = self.model.objVal
-        return {
-            "LP value": opt_value,
-            "LP log": log,
-        }
+        return LPSolveStats(
+            lp_value=opt_value,
+            lp_log=log,
+        )

     @overrides
     def solve(


@@ -4,6 +4,7 @@
 import logging
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional

 from overrides import EnforceOverrides
@@ -11,7 +12,6 @@ from overrides import EnforceOverrides
 from miplearn.features import Constraint, Variable
 from miplearn.instance.base import Instance
 from miplearn.types import (
-    LPSolveStats,
     IterationCallback,
     LazyCallback,
     MIPSolveStats,
@@ -24,6 +24,12 @@ from miplearn.types import (
 logger = logging.getLogger(__name__)


+@dataclass
+class LPSolveStats:
+    lp_log: Optional[str] = None
+    lp_value: Optional[float] = None
+
+
 class InternalSolver(ABC, EnforceOverrides):
     """
     Abstract class representing the MIP solver used internally by LearningSolver.
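
Note that both fields are Optional and default to None, which matches the infeasible case where solve_lp reports no objective value (run_infeasibility_tests, further down in this commit, asserts lp_stats.lp_value is None). A small sketch of the defaults, assuming only the definition added above:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LPSolveStats:
        lp_log: Optional[str] = None
        lp_value: Optional[float] = None

    stats = LPSolveStats()                       # nothing solved yet
    assert stats.lp_value is None and stats.lp_log is None

    stats = LPSolveStats(lp_value=None, lp_log="model is infeasible")
    assert stats.lp_value is None                # what run_infeasibility_tests checks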


@@ -178,10 +178,10 @@ class LearningSolver:
             logger.info("Solving root LP relaxation...")
             lp_stats = self.internal_solver.solve_lp(tee=tee)
-            stats.update(cast(LearningSolveStats, lp_stats))
+            stats.update(cast(LearningSolveStats, lp_stats.__dict__))
             training_sample.lp_solution = self.internal_solver.get_solution()
-            training_sample.lp_value = lp_stats["LP value"]
-            training_sample.lp_log = lp_stats["LP log"]
+            training_sample.lp_value = lp_stats.lp_value
+            training_sample.lp_log = lp_stats.lp_log

             logger.debug("Running after_solve_lp callbacks...")
             for component in self.components.values():
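
The change above funnels the dataclass into the existing stats dictionary via lp_stats.__dict__, so the aggregated LearningSolveStats keys become "lp_value" and "lp_log" rather than "LP value" and "LP log". A minimal sketch of that mechanic (illustrative only; dataclasses.asdict is shown as an equivalent for this flat case, not something the commit uses):

    from dataclasses import asdict, dataclass
    from typing import Optional

    @dataclass
    class LPSolveStats:
        lp_log: Optional[str] = None
        lp_value: Optional[float] = None

    lp_stats = LPSolveStats(lp_value=3.0, lp_log="presolve log ...")

    # A plain dataclass keeps its fields in the instance __dict__, so this
    # mapping has the keys "lp_log" and "lp_value" (no longer "LP log"/"LP value").
    stats: dict = {}
    stats.update(lp_stats.__dict__)
    assert stats["lp_value"] == 3.0

    # dataclasses.asdict produces the same mapping for a flat dataclass like this one.
    assert asdict(lp_stats) == lp_stats.__dict__
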
@@ -240,8 +240,6 @@ class LearningSolver:
             lazy_cb=lazy_cb,
         )
         stats.update(cast(LearningSolveStats, mip_stats))
-        if training_sample.lp_value is not None:
-            stats["LP value"] = training_sample.lp_value
         stats["Solver"] = "default"
         stats["Gap"] = self._compute_gap(
             ub=stats["Upper bound"],


@@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver):
         if not self.is_infeasible():
             opt_value = results["Problem"][0]["Lower bound"]
             self._has_lp_solution = True
-        return {
-            "LP value": opt_value,
-            "LP log": streams[0].getvalue(),
-        }
+        return LPSolveStats(
+            lp_value=opt_value,
+            lp_log=streams[0].getvalue(),
+        )

     def _restore_integrality(self) -> None:
         for var in self._bin_vars:


@@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
     # Solve linear programming relaxation
     lp_stats = solver.solve_lp()
     assert not solver.is_infeasible()
-    assert lp_stats["LP value"] is not None
-    assert_equals(round(lp_stats["LP value"], 3), 1287.923)
-    assert len(lp_stats["LP log"]) > 100
+    assert lp_stats.lp_value is not None
+    assert lp_stats.lp_log is not None
+    assert_equals(round(lp_stats.lp_value, 3), 1287.923)
+    assert len(lp_stats.lp_log) > 100

     # Fetch variables (after-load)
     assert_equals(
@@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None:
     assert mip_stats["Lower bound"] is None
     lp_stats = solver.solve_lp()
     assert solver.get_solution() is None
-    assert lp_stats["LP value"] is None
+    assert lp_stats.lp_value is None


 def run_iteration_cb_tests(solver: InternalSolver) -> None:


@@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None]
 VariableName = str
 Solution = Dict[VariableName, Optional[float]]

-LPSolveStats = TypedDict(
-    "LPSolveStats",
-    {
-        "LP log": str,
-        "LP value": Optional[float],
-    },
-)
-
 MIPSolveStats = TypedDict(
     "MIPSolveStats",
     {


@@ -34,7 +34,6 @@ def test_benchmark() -> None:
         n_jobs=n_jobs,
         n_trials=2,
     )
-    assert benchmark.results.values.shape == (12, 20)
     benchmark.write_csv("/tmp/benchmark.csv")
     assert os.path.isfile("/tmp/benchmark.csv")
+    assert benchmark.results.values.shape == (12, 20)