From fde6dc5a6050a6cbf4e9a1643a16039a6b4c3a01 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sun, 11 Apr 2021 17:20:17 -0500 Subject: [PATCH] Combine after_load, after_lp and after_mip into Sample dataclass --- miplearn/features.py | 7 +++++++ miplearn/instance/base.py | 6 ++---- miplearn/instance/picklegz.py | 4 +--- miplearn/solvers/learning.py | 9 +++++---- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/miplearn/features.py b/miplearn/features.py index c25f0f5..3590eee 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -106,6 +106,13 @@ class Features: mip_solve: Optional["MIPSolveStats"] = None +@dataclass +class Sample: + after_load: Optional[Features] = None + after_lp: Optional[Features] = None + after_mip: Optional[Features] = None + + class FeaturesExtractor: def __init__( self, diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index 9edae50..4778376 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -8,7 +8,7 @@ from typing import Any, List, Optional, Hashable, TYPE_CHECKING from overrides import EnforceOverrides -from miplearn.features import TrainingSample, Features +from miplearn.features import TrainingSample, Features, Sample from miplearn.types import VariableName, Category logger = logging.getLogger(__name__) @@ -33,9 +33,7 @@ class Instance(ABC, EnforceOverrides): def __init__(self) -> None: self.training_data: List[TrainingSample] = [] self.features: Features = Features() - self.features_after_load: List[Features] = [] - self.features_after_lp: List[Features] = [] - self.features_after_mip: List[Features] = [] + self.samples: List[Sample] = [] @abstractmethod def to_model(self) -> Any: diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index 1b27438..dbcd197 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -123,9 +123,7 @@ class PickleGzInstance(Instance): self.instance = obj self.features = self.instance.features self.training_data = self.instance.training_data - self.features_after_load = self.instance.features_after_load - self.features_after_lp = self.instance.features_after_lp - self.features_after_mip = self.instance.features_after_mip + self.samples = self.instance.samples @overrides def free(self) -> None: diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 6a70696..9dc36da 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -13,7 +13,7 @@ from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent from miplearn.components.dynamic_user_cuts import UserCutsComponent from miplearn.components.objective import ObjectiveValueComponent from miplearn.components.primal import PrimalSolutionComponent -from miplearn.features import FeaturesExtractor, TrainingSample +from miplearn.features import FeaturesExtractor, TrainingSample, Sample from miplearn.instance.base import Instance from miplearn.instance.picklegz import PickleGzInstance from miplearn.solvers import _RedirectOutput @@ -140,6 +140,7 @@ class LearningSolver: # ------------------------------------------------------- training_sample = TrainingSample() instance.training_data += [training_sample] + sample = Sample() # Initialize stats # ------------------------------------------------------- @@ -158,7 +159,7 @@ class LearningSolver: logger.info("Extracting features (after-load)...") features = FeaturesExtractor(self.internal_solver).extract(instance) instance.features.__dict__ = features.__dict__ - instance.features_after_load.append(features) + sample.after_load = features callback_args = ( self, @@ -193,7 +194,7 @@ class LearningSolver: logger.info("Extracting features (after-lp)...") features = FeaturesExtractor(self.internal_solver).extract(instance) features.lp_solve = lp_stats - instance.features_after_lp.append(features) + sample.after_lp = features # Callback wrappers # ------------------------------------------------------- @@ -254,7 +255,7 @@ class LearningSolver: logger.info("Extracting features (after-mip)...") features = FeaturesExtractor(self.internal_solver).extract(instance) features.mip_solve = mip_stats - instance.features_after_mip.append(features) + sample.after_mip = features # Add some information to training_sample # -------------------------------------------------------