mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Combine after_load, after_lp and after_mip into Sample dataclass
This commit is contained in:
@@ -106,6 +106,13 @@ class Features:
|
|||||||
mip_solve: Optional["MIPSolveStats"] = None
|
mip_solve: Optional["MIPSolveStats"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Sample:
|
||||||
|
after_load: Optional[Features] = None
|
||||||
|
after_lp: Optional[Features] = None
|
||||||
|
after_mip: Optional[Features] = None
|
||||||
|
|
||||||
|
|
||||||
class FeaturesExtractor:
|
class FeaturesExtractor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, List, Optional, Hashable, TYPE_CHECKING
|
|||||||
|
|
||||||
from overrides import EnforceOverrides
|
from overrides import EnforceOverrides
|
||||||
|
|
||||||
from miplearn.features import TrainingSample, Features
|
from miplearn.features import TrainingSample, Features, Sample
|
||||||
from miplearn.types import VariableName, Category
|
from miplearn.types import VariableName, Category
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -33,9 +33,7 @@ class Instance(ABC, EnforceOverrides):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.training_data: List[TrainingSample] = []
|
self.training_data: List[TrainingSample] = []
|
||||||
self.features: Features = Features()
|
self.features: Features = Features()
|
||||||
self.features_after_load: List[Features] = []
|
self.samples: List[Sample] = []
|
||||||
self.features_after_lp: List[Features] = []
|
|
||||||
self.features_after_mip: List[Features] = []
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
|
|||||||
@@ -123,9 +123,7 @@ class PickleGzInstance(Instance):
|
|||||||
self.instance = obj
|
self.instance = obj
|
||||||
self.features = self.instance.features
|
self.features = self.instance.features
|
||||||
self.training_data = self.instance.training_data
|
self.training_data = self.instance.training_data
|
||||||
self.features_after_load = self.instance.features_after_load
|
self.samples = self.instance.samples
|
||||||
self.features_after_lp = self.instance.features_after_lp
|
|
||||||
self.features_after_mip = self.instance.features_after_mip
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def free(self) -> None:
|
def free(self) -> None:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent
|
|||||||
from miplearn.components.dynamic_user_cuts import UserCutsComponent
|
from miplearn.components.dynamic_user_cuts import UserCutsComponent
|
||||||
from miplearn.components.objective import ObjectiveValueComponent
|
from miplearn.components.objective import ObjectiveValueComponent
|
||||||
from miplearn.components.primal import PrimalSolutionComponent
|
from miplearn.components.primal import PrimalSolutionComponent
|
||||||
from miplearn.features import FeaturesExtractor, TrainingSample
|
from miplearn.features import FeaturesExtractor, TrainingSample, Sample
|
||||||
from miplearn.instance.base import Instance
|
from miplearn.instance.base import Instance
|
||||||
from miplearn.instance.picklegz import PickleGzInstance
|
from miplearn.instance.picklegz import PickleGzInstance
|
||||||
from miplearn.solvers import _RedirectOutput
|
from miplearn.solvers import _RedirectOutput
|
||||||
@@ -140,6 +140,7 @@ class LearningSolver:
|
|||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
training_sample = TrainingSample()
|
training_sample = TrainingSample()
|
||||||
instance.training_data += [training_sample]
|
instance.training_data += [training_sample]
|
||||||
|
sample = Sample()
|
||||||
|
|
||||||
# Initialize stats
|
# Initialize stats
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
@@ -158,7 +159,7 @@ class LearningSolver:
|
|||||||
logger.info("Extracting features (after-load)...")
|
logger.info("Extracting features (after-load)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
instance.features.__dict__ = features.__dict__
|
instance.features.__dict__ = features.__dict__
|
||||||
instance.features_after_load.append(features)
|
sample.after_load = features
|
||||||
|
|
||||||
callback_args = (
|
callback_args = (
|
||||||
self,
|
self,
|
||||||
@@ -193,7 +194,7 @@ class LearningSolver:
|
|||||||
logger.info("Extracting features (after-lp)...")
|
logger.info("Extracting features (after-lp)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
features.lp_solve = lp_stats
|
features.lp_solve = lp_stats
|
||||||
instance.features_after_lp.append(features)
|
sample.after_lp = features
|
||||||
|
|
||||||
# Callback wrappers
|
# Callback wrappers
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
@@ -254,7 +255,7 @@ class LearningSolver:
|
|||||||
logger.info("Extracting features (after-mip)...")
|
logger.info("Extracting features (after-mip)...")
|
||||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
features.mip_solve = mip_stats
|
features.mip_solve = mip_stats
|
||||||
instance.features_after_mip.append(features)
|
sample.after_mip = features
|
||||||
|
|
||||||
# Add some information to training_sample
|
# Add some information to training_sample
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user