diff --git a/miplearn/features.py b/miplearn/features.py index 5f02626..ea0ea90 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -8,11 +8,12 @@ from dataclasses import dataclass from math import log, isfinite from typing import TYPE_CHECKING, Dict, Optional, Set, List, Hashable -from miplearn.types import Solution, VariableName, Category import numpy as np +from miplearn.types import Solution, Category + if TYPE_CHECKING: - from miplearn.solvers.internal import InternalSolver + from miplearn.solvers.internal import InternalSolver, LPSolveStats from miplearn.instance.base import Instance @@ -101,6 +102,7 @@ class Features: instance: Optional[InstanceFeatures] = None variables: Optional[Dict[str, Variable]] = None constraints: Optional[Dict[str, Constraint]] = None + lp_solve: Optional["LPSolveStats"] = None class FeaturesExtractor: diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 226e77c..8c226db 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -150,6 +150,7 @@ class GurobiSolver(InternalSolver): return LPSolveStats( lp_value=opt_value, lp_log=log, + lp_wallclock_time=self.model.runtime, ) @overrides diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index 5ecd27c..95b02c4 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) class LPSolveStats: lp_log: Optional[str] = None lp_value: Optional[float] = None + lp_wallclock_time: Optional[float] = None class InternalSolver(ABC, EnforceOverrides): diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index fce1d77..f1cefc3 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -191,6 +191,7 @@ class LearningSolver: # ------------------------------------------------------- logger.info("Extracting features (after-lp)...") features = FeaturesExtractor(self.internal_solver).extract(instance) + features.lp_solve = lp_stats instance.features_after_lp.append(features) # Callback wrappers diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index ba5e61d..5c7f905 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -90,6 +90,7 @@ class BasePyomoSolver(InternalSolver): return LPSolveStats( lp_value=opt_value, lp_log=streams[0].getvalue(), + lp_wallclock_time=results["Solver"][0]["Wallclock time"], ) def _restore_integrality(self) -> None: diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index 464af2c..1e53d85 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -138,9 +138,11 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: lp_stats = solver.solve_lp() assert not solver.is_infeasible() assert lp_stats.lp_value is not None - assert lp_stats.lp_log is not None assert_equals(round(lp_stats.lp_value, 3), 1287.923) + assert lp_stats.lp_log is not None assert len(lp_stats.lp_log) > 100 + assert lp_stats.lp_wallclock_time is not None + assert lp_stats.lp_wallclock_time > 0 # Fetch variables (after-load) assert_equals( diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 07cdea8..1183aa4 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -36,4 +36,4 @@ def test_benchmark() -> None: ) benchmark.write_csv("/tmp/benchmark.csv") assert os.path.isfile("/tmp/benchmark.csv") - assert benchmark.results.values.shape == (12, 20) + assert benchmark.results.values.shape == (12, 21)