Add lp_stats to after-lp features

master
Alinson S. Xavier 5 years ago
parent 945f6a091c
commit 2bc1e21f8e
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -8,11 +8,12 @@ from dataclasses import dataclass
from math import log, isfinite
from typing import TYPE_CHECKING, Dict, Optional, Set, List, Hashable
from miplearn.types import Solution, VariableName, Category
import numpy as np
from miplearn.types import Solution, Category
if TYPE_CHECKING:
from miplearn.solvers.internal import InternalSolver
from miplearn.solvers.internal import InternalSolver, LPSolveStats
from miplearn.instance.base import Instance
@ -101,6 +102,7 @@ class Features:
instance: Optional[InstanceFeatures] = None
variables: Optional[Dict[str, Variable]] = None
constraints: Optional[Dict[str, Constraint]] = None
lp_solve: Optional["LPSolveStats"] = None
class FeaturesExtractor:

@ -150,6 +150,7 @@ class GurobiSolver(InternalSolver):
return LPSolveStats(
lp_value=opt_value,
lp_log=log,
lp_wallclock_time=self.model.runtime,
)
@overrides

@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class LPSolveStats:
lp_log: Optional[str] = None
lp_value: Optional[float] = None
lp_wallclock_time: Optional[float] = None
class InternalSolver(ABC, EnforceOverrides):

@ -191,6 +191,7 @@ class LearningSolver:
# -------------------------------------------------------
logger.info("Extracting features (after-lp)...")
features = FeaturesExtractor(self.internal_solver).extract(instance)
features.lp_solve = lp_stats
instance.features_after_lp.append(features)
# Callback wrappers

@ -90,6 +90,7 @@ class BasePyomoSolver(InternalSolver):
return LPSolveStats(
lp_value=opt_value,
lp_log=streams[0].getvalue(),
lp_wallclock_time=results["Solver"][0]["Wallclock time"],
)
def _restore_integrality(self) -> None:

@ -138,9 +138,11 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
lp_stats = solver.solve_lp()
assert not solver.is_infeasible()
assert lp_stats.lp_value is not None
assert lp_stats.lp_log is not None
assert_equals(round(lp_stats.lp_value, 3), 1287.923)
assert lp_stats.lp_log is not None
assert len(lp_stats.lp_log) > 100
assert lp_stats.lp_wallclock_time is not None
assert lp_stats.lp_wallclock_time > 0
# Fetch variables (after-load)
assert_equals(

@ -36,4 +36,4 @@ def test_benchmark() -> None:
)
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")
assert benchmark.results.values.shape == (12, 20)
assert benchmark.results.values.shape == (12, 21)

Loading…
Cancel
Save