mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Convert LPSolveStats into dataclass
This commit is contained in:
@@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver):
|
||||
opt_value = None
|
||||
if not self.is_infeasible():
|
||||
opt_value = self.model.objVal
|
||||
return {
|
||||
"LP value": opt_value,
|
||||
"LP log": log,
|
||||
}
|
||||
return LPSolveStats(
|
||||
lp_value=opt_value,
|
||||
lp_log=log,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def solve(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from overrides import EnforceOverrides
|
||||
@@ -11,7 +12,6 @@ from overrides import EnforceOverrides
|
||||
from miplearn.features import Constraint, Variable
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import (
|
||||
LPSolveStats,
|
||||
IterationCallback,
|
||||
LazyCallback,
|
||||
MIPSolveStats,
|
||||
@@ -24,6 +24,12 @@ from miplearn.types import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LPSolveStats:
|
||||
lp_log: Optional[str] = None
|
||||
lp_value: Optional[float] = None
|
||||
|
||||
|
||||
class InternalSolver(ABC, EnforceOverrides):
|
||||
"""
|
||||
Abstract class representing the MIP solver used internally by LearningSolver.
|
||||
|
||||
@@ -178,10 +178,10 @@ class LearningSolver:
|
||||
|
||||
logger.info("Solving root LP relaxation...")
|
||||
lp_stats = self.internal_solver.solve_lp(tee=tee)
|
||||
stats.update(cast(LearningSolveStats, lp_stats))
|
||||
stats.update(cast(LearningSolveStats, lp_stats.__dict__))
|
||||
training_sample.lp_solution = self.internal_solver.get_solution()
|
||||
training_sample.lp_value = lp_stats["LP value"]
|
||||
training_sample.lp_log = lp_stats["LP log"]
|
||||
training_sample.lp_value = lp_stats.lp_value
|
||||
training_sample.lp_log = lp_stats.lp_log
|
||||
|
||||
logger.debug("Running after_solve_lp callbacks...")
|
||||
for component in self.components.values():
|
||||
@@ -240,8 +240,6 @@ class LearningSolver:
|
||||
lazy_cb=lazy_cb,
|
||||
)
|
||||
stats.update(cast(LearningSolveStats, mip_stats))
|
||||
if training_sample.lp_value is not None:
|
||||
stats["LP value"] = training_sample.lp_value
|
||||
stats["Solver"] = "default"
|
||||
stats["Gap"] = self._compute_gap(
|
||||
ub=stats["Upper bound"],
|
||||
|
||||
@@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver):
|
||||
if not self.is_infeasible():
|
||||
opt_value = results["Problem"][0]["Lower bound"]
|
||||
self._has_lp_solution = True
|
||||
return {
|
||||
"LP value": opt_value,
|
||||
"LP log": streams[0].getvalue(),
|
||||
}
|
||||
return LPSolveStats(
|
||||
lp_value=opt_value,
|
||||
lp_log=streams[0].getvalue(),
|
||||
)
|
||||
|
||||
def _restore_integrality(self) -> None:
|
||||
for var in self._bin_vars:
|
||||
|
||||
@@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
||||
# Solve linear programming relaxation
|
||||
lp_stats = solver.solve_lp()
|
||||
assert not solver.is_infeasible()
|
||||
assert lp_stats["LP value"] is not None
|
||||
assert_equals(round(lp_stats["LP value"], 3), 1287.923)
|
||||
assert len(lp_stats["LP log"]) > 100
|
||||
assert lp_stats.lp_value is not None
|
||||
assert lp_stats.lp_log is not None
|
||||
assert_equals(round(lp_stats.lp_value, 3), 1287.923)
|
||||
assert len(lp_stats.lp_log) > 100
|
||||
|
||||
# Fetch variables (after-load)
|
||||
assert_equals(
|
||||
@@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None:
|
||||
assert mip_stats["Lower bound"] is None
|
||||
lp_stats = solver.solve_lp()
|
||||
assert solver.get_solution() is None
|
||||
assert lp_stats["LP value"] is None
|
||||
assert lp_stats.lp_value is None
|
||||
|
||||
|
||||
def run_iteration_cb_tests(solver: InternalSolver) -> None:
|
||||
|
||||
@@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None]
|
||||
VariableName = str
|
||||
Solution = Dict[VariableName, Optional[float]]
|
||||
|
||||
LPSolveStats = TypedDict(
|
||||
"LPSolveStats",
|
||||
{
|
||||
"LP log": str,
|
||||
"LP value": Optional[float],
|
||||
},
|
||||
)
|
||||
|
||||
MIPSolveStats = TypedDict(
|
||||
"MIPSolveStats",
|
||||
{
|
||||
|
||||
@@ -34,7 +34,6 @@ def test_benchmark() -> None:
|
||||
n_jobs=n_jobs,
|
||||
n_trials=2,
|
||||
)
|
||||
assert benchmark.results.values.shape == (12, 20)
|
||||
|
||||
benchmark.write_csv("/tmp/benchmark.csv")
|
||||
assert os.path.isfile("/tmp/benchmark.csv")
|
||||
assert benchmark.results.values.shape == (12, 20)
|
||||
|
||||
Reference in New Issue
Block a user