mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Add lp_stats to after-lp features
This commit is contained in:
@@ -8,11 +8,12 @@ from dataclasses import dataclass
|
||||
from math import log, isfinite
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, List, Hashable
|
||||
|
||||
from miplearn.types import Solution, VariableName, Category
|
||||
import numpy as np
|
||||
|
||||
from miplearn.types import Solution, Category
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.internal import InternalSolver, LPSolveStats
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
|
||||
@@ -101,6 +102,7 @@ class Features:
|
||||
instance: Optional[InstanceFeatures] = None
|
||||
variables: Optional[Dict[str, Variable]] = None
|
||||
constraints: Optional[Dict[str, Constraint]] = None
|
||||
lp_solve: Optional["LPSolveStats"] = None
|
||||
|
||||
|
||||
class FeaturesExtractor:
|
||||
|
||||
@@ -150,6 +150,7 @@ class GurobiSolver(InternalSolver):
|
||||
return LPSolveStats(
|
||||
lp_value=opt_value,
|
||||
lp_log=log,
|
||||
lp_wallclock_time=self.model.runtime,
|
||||
)
|
||||
|
||||
@overrides
|
||||
|
||||
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
class LPSolveStats:
|
||||
lp_log: Optional[str] = None
|
||||
lp_value: Optional[float] = None
|
||||
lp_wallclock_time: Optional[float] = None
|
||||
|
||||
|
||||
class InternalSolver(ABC, EnforceOverrides):
|
||||
|
||||
@@ -191,6 +191,7 @@ class LearningSolver:
|
||||
# -------------------------------------------------------
|
||||
logger.info("Extracting features (after-lp)...")
|
||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
features.lp_solve = lp_stats
|
||||
instance.features_after_lp.append(features)
|
||||
|
||||
# Callback wrappers
|
||||
|
||||
@@ -90,6 +90,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
return LPSolveStats(
|
||||
lp_value=opt_value,
|
||||
lp_log=streams[0].getvalue(),
|
||||
lp_wallclock_time=results["Solver"][0]["Wallclock time"],
|
||||
)
|
||||
|
||||
def _restore_integrality(self) -> None:
|
||||
|
||||
@@ -138,9 +138,11 @@ def run_basic_usage_tests(solver: InternalSolver) -> None:
|
||||
lp_stats = solver.solve_lp()
|
||||
assert not solver.is_infeasible()
|
||||
assert lp_stats.lp_value is not None
|
||||
assert lp_stats.lp_log is not None
|
||||
assert_equals(round(lp_stats.lp_value, 3), 1287.923)
|
||||
assert lp_stats.lp_log is not None
|
||||
assert len(lp_stats.lp_log) > 100
|
||||
assert lp_stats.lp_wallclock_time is not None
|
||||
assert lp_stats.lp_wallclock_time > 0
|
||||
|
||||
# Fetch variables (after-load)
|
||||
assert_equals(
|
||||
|
||||
Reference in New Issue
Block a user