mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 17:38:51 -06:00
Convert LPSolveStats into dataclass
This commit is contained in:
@@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver):
|
|||||||
opt_value = None
|
opt_value = None
|
||||||
if not self.is_infeasible():
|
if not self.is_infeasible():
|
||||||
opt_value = self.model.objVal
|
opt_value = self.model.objVal
|
||||||
return {
|
return LPSolveStats(
|
||||||
"LP value": opt_value,
|
lp_value=opt_value,
|
||||||
"LP log": log,
|
lp_log=log,
|
||||||
}
|
)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def solve(
|
def solve(
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from overrides import EnforceOverrides
|
from overrides import EnforceOverrides
|
||||||
@@ -11,7 +12,6 @@ from overrides import EnforceOverrides
|
|||||||
from miplearn.features import Constraint, Variable
|
from miplearn.features import Constraint, Variable
|
||||||
from miplearn.instance.base import Instance
|
from miplearn.instance.base import Instance
|
||||||
from miplearn.types import (
|
from miplearn.types import (
|
||||||
LPSolveStats,
|
|
||||||
IterationCallback,
|
IterationCallback,
|
||||||
LazyCallback,
|
LazyCallback,
|
||||||
MIPSolveStats,
|
MIPSolveStats,
|
||||||
@@ -24,6 +24,12 @@ from miplearn.types import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LPSolveStats:
|
||||||
|
lp_log: Optional[str] = None
|
||||||
|
lp_value: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class InternalSolver(ABC, EnforceOverrides):
|
class InternalSolver(ABC, EnforceOverrides):
|
||||||
"""
|
"""
|
||||||
Abstract class representing the MIP solver used internally by LearningSolver.
|
Abstract class representing the MIP solver used internally by LearningSolver.
|
||||||
|
|||||||
@@ -178,10 +178,10 @@ class LearningSolver:
|
|||||||
|
|
||||||
logger.info("Solving root LP relaxation...")
|
logger.info("Solving root LP relaxation...")
|
||||||
lp_stats = self.internal_solver.solve_lp(tee=tee)
|
lp_stats = self.internal_solver.solve_lp(tee=tee)
|
||||||
stats.update(cast(LearningSolveStats, lp_stats))
|
stats.update(cast(LearningSolveStats, lp_stats.__dict__))
|
||||||
training_sample.lp_solution = self.internal_solver.get_solution()
|
training_sample.lp_solution = self.internal_solver.get_solution()
|
||||||
training_sample.lp_value = lp_stats["LP value"]
|
training_sample.lp_value = lp_stats.lp_value
|
||||||
training_sample.lp_log = lp_stats["LP log"]
|
training_sample.lp_log = lp_stats.lp_log
|
||||||
|
|
||||||
logger.debug("Running after_solve_lp callbacks...")
|
logger.debug("Running after_solve_lp callbacks...")
|
||||||
for component in self.components.values():
|
for component in self.components.values():
|
||||||
@@ -240,8 +240,6 @@ class LearningSolver:
|
|||||||
lazy_cb=lazy_cb,
|
lazy_cb=lazy_cb,
|
||||||
)
|
)
|
||||||
stats.update(cast(LearningSolveStats, mip_stats))
|
stats.update(cast(LearningSolveStats, mip_stats))
|
||||||
if training_sample.lp_value is not None:
|
|
||||||
stats["LP value"] = training_sample.lp_value
|
|
||||||
stats["Solver"] = "default"
|
stats["Solver"] = "default"
|
||||||
stats["Gap"] = self._compute_gap(
|
stats["Gap"] = self._compute_gap(
|
||||||
ub=stats["Upper bound"],
|
ub=stats["Upper bound"],
|
||||||
|
|||||||
@@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
if not self.is_infeasible():
|
if not self.is_infeasible():
|
||||||
opt_value = results["Problem"][0]["Lower bound"]
|
opt_value = results["Problem"][0]["Lower bound"]
|
||||||
self._has_lp_solution = True
|
self._has_lp_solution = True
|
||||||
return {
|
return LPSolveStats(
|
||||||
"LP value": opt_value,
|
lp_value=opt_value,
|
||||||
"LP log": streams[0].getvalue(),
|
lp_log=streams[0].getvalue(),
|
||||||
}
|
)
|
||||||
|
|
||||||
def _restore_integrality(self) -> None:
|
def _restore_integrality(self) -> None:
|
||||||
for var in self._bin_vars:
|
for var in self._bin_vars:
|
||||||
|
|||||||
@@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
|||||||
# Solve linear programming relaxation
|
# Solve linear programming relaxation
|
||||||
lp_stats = solver.solve_lp()
|
lp_stats = solver.solve_lp()
|
||||||
assert not solver.is_infeasible()
|
assert not solver.is_infeasible()
|
||||||
assert lp_stats["LP value"] is not None
|
assert lp_stats.lp_value is not None
|
||||||
assert_equals(round(lp_stats["LP value"], 3), 1287.923)
|
assert lp_stats.lp_log is not None
|
||||||
assert len(lp_stats["LP log"]) > 100
|
assert_equals(round(lp_stats.lp_value, 3), 1287.923)
|
||||||
|
assert len(lp_stats.lp_log) > 100
|
||||||
|
|
||||||
# Fetch variables (after-load)
|
# Fetch variables (after-load)
|
||||||
assert_equals(
|
assert_equals(
|
||||||
@@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None:
|
|||||||
assert mip_stats["Lower bound"] is None
|
assert mip_stats["Lower bound"] is None
|
||||||
lp_stats = solver.solve_lp()
|
lp_stats = solver.solve_lp()
|
||||||
assert solver.get_solution() is None
|
assert solver.get_solution() is None
|
||||||
assert lp_stats["LP value"] is None
|
assert lp_stats.lp_value is None
|
||||||
|
|
||||||
|
|
||||||
def run_iteration_cb_tests(solver: InternalSolver) -> None:
|
def run_iteration_cb_tests(solver: InternalSolver) -> None:
|
||||||
|
|||||||
@@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None]
|
|||||||
VariableName = str
|
VariableName = str
|
||||||
Solution = Dict[VariableName, Optional[float]]
|
Solution = Dict[VariableName, Optional[float]]
|
||||||
|
|
||||||
LPSolveStats = TypedDict(
|
|
||||||
"LPSolveStats",
|
|
||||||
{
|
|
||||||
"LP log": str,
|
|
||||||
"LP value": Optional[float],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
MIPSolveStats = TypedDict(
|
MIPSolveStats = TypedDict(
|
||||||
"MIPSolveStats",
|
"MIPSolveStats",
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ def test_benchmark() -> None:
|
|||||||
n_jobs=n_jobs,
|
n_jobs=n_jobs,
|
||||||
n_trials=2,
|
n_trials=2,
|
||||||
)
|
)
|
||||||
assert benchmark.results.values.shape == (12, 20)
|
|
||||||
|
|
||||||
benchmark.write_csv("/tmp/benchmark.csv")
|
benchmark.write_csv("/tmp/benchmark.csv")
|
||||||
assert os.path.isfile("/tmp/benchmark.csv")
|
assert os.path.isfile("/tmp/benchmark.csv")
|
||||||
|
assert benchmark.results.values.shape == (12, 20)
|
||||||
|
|||||||
Reference in New Issue
Block a user