Combine after_load, after_lp and after_mip into Sample dataclass

master
Alinson S. Xavier 5 years ago
parent 2d4ded1978
commit fde6dc5a60
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -106,6 +106,13 @@ class Features:
mip_solve: Optional["MIPSolveStats"] = None mip_solve: Optional["MIPSolveStats"] = None
@dataclass
class Sample:
after_load: Optional[Features] = None
after_lp: Optional[Features] = None
after_mip: Optional[Features] = None
class FeaturesExtractor: class FeaturesExtractor:
def __init__( def __init__(
self, self,

@ -8,7 +8,7 @@ from typing import Any, List, Optional, Hashable, TYPE_CHECKING
from overrides import EnforceOverrides from overrides import EnforceOverrides
from miplearn.features import TrainingSample, Features from miplearn.features import TrainingSample, Features, Sample
from miplearn.types import VariableName, Category from miplearn.types import VariableName, Category
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,9 +33,7 @@ class Instance(ABC, EnforceOverrides):
def __init__(self) -> None: def __init__(self) -> None:
self.training_data: List[TrainingSample] = [] self.training_data: List[TrainingSample] = []
self.features: Features = Features() self.features: Features = Features()
self.features_after_load: List[Features] = [] self.samples: List[Sample] = []
self.features_after_lp: List[Features] = []
self.features_after_mip: List[Features] = []
@abstractmethod @abstractmethod
def to_model(self) -> Any: def to_model(self) -> Any:

@ -123,9 +123,7 @@ class PickleGzInstance(Instance):
self.instance = obj self.instance = obj
self.features = self.instance.features self.features = self.instance.features
self.training_data = self.instance.training_data self.training_data = self.instance.training_data
self.features_after_load = self.instance.features_after_load self.samples = self.instance.samples
self.features_after_lp = self.instance.features_after_lp
self.features_after_mip = self.instance.features_after_mip
@overrides @overrides
def free(self) -> None: def free(self) -> None:

@ -13,7 +13,7 @@ from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent
from miplearn.components.dynamic_user_cuts import UserCutsComponent from miplearn.components.dynamic_user_cuts import UserCutsComponent
from miplearn.components.objective import ObjectiveValueComponent from miplearn.components.objective import ObjectiveValueComponent
from miplearn.components.primal import PrimalSolutionComponent from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features import FeaturesExtractor, TrainingSample from miplearn.features import FeaturesExtractor, TrainingSample, Sample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.instance.picklegz import PickleGzInstance from miplearn.instance.picklegz import PickleGzInstance
from miplearn.solvers import _RedirectOutput from miplearn.solvers import _RedirectOutput
@ -140,6 +140,7 @@ class LearningSolver:
# ------------------------------------------------------- # -------------------------------------------------------
training_sample = TrainingSample() training_sample = TrainingSample()
instance.training_data += [training_sample] instance.training_data += [training_sample]
sample = Sample()
# Initialize stats # Initialize stats
# ------------------------------------------------------- # -------------------------------------------------------
@ -158,7 +159,7 @@ class LearningSolver:
logger.info("Extracting features (after-load)...") logger.info("Extracting features (after-load)...")
features = FeaturesExtractor(self.internal_solver).extract(instance) features = FeaturesExtractor(self.internal_solver).extract(instance)
instance.features.__dict__ = features.__dict__ instance.features.__dict__ = features.__dict__
instance.features_after_load.append(features) sample.after_load = features
callback_args = ( callback_args = (
self, self,
@ -193,7 +194,7 @@ class LearningSolver:
logger.info("Extracting features (after-lp)...") logger.info("Extracting features (after-lp)...")
features = FeaturesExtractor(self.internal_solver).extract(instance) features = FeaturesExtractor(self.internal_solver).extract(instance)
features.lp_solve = lp_stats features.lp_solve = lp_stats
instance.features_after_lp.append(features) sample.after_lp = features
# Callback wrappers # Callback wrappers
# ------------------------------------------------------- # -------------------------------------------------------
@ -254,7 +255,7 @@ class LearningSolver:
logger.info("Extracting features (after-mip)...") logger.info("Extracting features (after-mip)...")
features = FeaturesExtractor(self.internal_solver).extract(instance) features = FeaturesExtractor(self.internal_solver).extract(instance)
features.mip_solve = mip_stats features.mip_solve = mip_stats
instance.features_after_mip.append(features) sample.after_mip = features
# Add some information to training_sample # Add some information to training_sample
# ------------------------------------------------------- # -------------------------------------------------------

Loading…
Cancel
Save