From 945f6a091cdce5bb2bc815f8a4b67b6bad839241 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sun, 11 Apr 2021 08:41:50 -0500 Subject: [PATCH] Convert LPSolveStats into dataclass --- miplearn/solvers/gurobi.py | 8 ++++---- miplearn/solvers/internal.py | 8 +++++++- miplearn/solvers/learning.py | 8 +++----- miplearn/solvers/pyomo/base.py | 8 ++++---- miplearn/solvers/tests/__init__.py | 9 +++++---- miplearn/types.py | 8 -------- tests/test_benchmark.py | 3 +-- 7 files changed, 24 insertions(+), 28 deletions(-) diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 0b4c969..226e77c 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -147,10 +147,10 @@ class GurobiSolver(InternalSolver): opt_value = None if not self.is_infeasible(): opt_value = self.model.objVal - return { - "LP value": opt_value, - "LP log": log, - } + return LPSolveStats( + lp_value=opt_value, + lp_log=log, + ) @overrides def solve( diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index 69e9960..5ecd27c 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -4,6 +4,7 @@ import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Dict, List, Optional from overrides import EnforceOverrides @@ -11,7 +12,6 @@ from overrides import EnforceOverrides from miplearn.features import Constraint, Variable from miplearn.instance.base import Instance from miplearn.types import ( - LPSolveStats, IterationCallback, LazyCallback, MIPSolveStats, @@ -24,6 +24,12 @@ from miplearn.types import ( logger = logging.getLogger(__name__) +@dataclass +class LPSolveStats: + lp_log: Optional[str] = None + lp_value: Optional[float] = None + + class InternalSolver(ABC, EnforceOverrides): """ Abstract class representing the MIP solver used internally by LearningSolver. diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index de26229..fce1d77 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -178,10 +178,10 @@ class LearningSolver: logger.info("Solving root LP relaxation...") lp_stats = self.internal_solver.solve_lp(tee=tee) - stats.update(cast(LearningSolveStats, lp_stats)) + stats.update(cast(LearningSolveStats, lp_stats.__dict__)) training_sample.lp_solution = self.internal_solver.get_solution() - training_sample.lp_value = lp_stats["LP value"] - training_sample.lp_log = lp_stats["LP log"] + training_sample.lp_value = lp_stats.lp_value + training_sample.lp_log = lp_stats.lp_log logger.debug("Running after_solve_lp callbacks...") for component in self.components.values(): @@ -240,8 +240,6 @@ class LearningSolver: lazy_cb=lazy_cb, ) stats.update(cast(LearningSolveStats, mip_stats)) - if training_sample.lp_value is not None: - stats["LP value"] = training_sample.lp_value stats["Solver"] = "default" stats["Gap"] = self._compute_gap( ub=stats["Upper bound"], diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index a85b6e0..ba5e61d 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -87,10 +87,10 @@ class BasePyomoSolver(InternalSolver): if not self.is_infeasible(): opt_value = results["Problem"][0]["Lower bound"] self._has_lp_solution = True - return { - "LP value": opt_value, - "LP log": streams[0].getvalue(), - } + return LPSolveStats( + lp_value=opt_value, + lp_log=streams[0].getvalue(), + ) def _restore_integrality(self) -> None: for var in self._bin_vars: diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index 0d0fe77..464af2c 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -137,9 +137,10 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Solve linear programming relaxation lp_stats = solver.solve_lp() assert not solver.is_infeasible() - assert lp_stats["LP value"] is not None - assert_equals(round(lp_stats["LP value"], 3), 1287.923) - assert len(lp_stats["LP log"]) > 100 + assert lp_stats.lp_value is not None + assert lp_stats.lp_log is not None + assert_equals(round(lp_stats.lp_value, 3), 1287.923) + assert len(lp_stats.lp_log) > 100 # Fetch variables (after-load) assert_equals( @@ -362,7 +363,7 @@ def run_infeasibility_tests(solver: InternalSolver) -> None: assert mip_stats["Lower bound"] is None lp_stats = solver.solve_lp() assert solver.get_solution() is None - assert lp_stats["LP value"] is None + assert lp_stats.lp_value is None def run_iteration_cb_tests(solver: InternalSolver) -> None: diff --git a/miplearn/types.py b/miplearn/types.py index fe66a22..5adea82 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -19,14 +19,6 @@ UserCutCallback = Callable[["InternalSolver", Any], None] VariableName = str Solution = Dict[VariableName, Optional[float]] -LPSolveStats = TypedDict( - "LPSolveStats", - { - "LP log": str, - "LP value": Optional[float], - }, -) - MIPSolveStats = TypedDict( "MIPSolveStats", { diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 4d5303e..07cdea8 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -34,7 +34,6 @@ def test_benchmark() -> None: n_jobs=n_jobs, n_trials=2, ) - assert benchmark.results.values.shape == (12, 20) - benchmark.write_csv("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv") + assert benchmark.results.values.shape == (12, 20)