diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index 2466928..9edae50 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -33,6 +33,9 @@ class Instance(ABC, EnforceOverrides): def __init__(self) -> None: self.training_data: List[TrainingSample] = [] self.features: Features = Features() + self.features_after_load: List[Features] = [] + self.features_after_lp: List[Features] = [] + self.features_after_mip: List[Features] = [] @abstractmethod def to_model(self) -> Any: diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index 7d165e7..1b27438 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -123,6 +123,9 @@ class PickleGzInstance(Instance): self.instance = obj self.features = self.instance.features self.training_data = self.instance.training_data + self.features_after_load = self.instance.features_after_load + self.features_after_lp = self.instance.features_after_lp + self.features_after_mip = self.instance.features_after_mip @overrides def free(self) -> None: diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index dbf5a4b..de26229 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -153,11 +153,12 @@ class LearningSolver: assert isinstance(self.internal_solver, InternalSolver) self.internal_solver.set_instance(instance, model) - # Extract features + # Extract features (after-load) # ------------------------------------------------------- - if instance.features.instance is None: - logger.info("Extracting features...") - FeaturesExtractor(self.internal_solver).extract(instance) + logger.info("Extracting features (after-load)...") + features = FeaturesExtractor(self.internal_solver).extract(instance) + instance.features.__dict__ = features.__dict__ + instance.features_after_load.append(features) callback_args = ( self, @@ -186,6 +187,12 @@ class LearningSolver: for component in self.components.values(): component.after_solve_lp(*callback_args) + # Extract features (after-lp) + # ------------------------------------------------------- + logger.info("Extracting features (after-lp)...") + features = FeaturesExtractor(self.internal_solver).extract(instance) + instance.features_after_lp.append(features) + # Callback wrappers # ------------------------------------------------------- def iteration_cb_wrapper() -> bool: @@ -242,6 +249,12 @@ class LearningSolver: ) stats["Mode"] = self.mode + # Extract features (after-mip) + # ------------------------------------------------------- + logger.info("Extracting features (after-mip)...") + features = FeaturesExtractor(self.internal_solver).extract(instance) + instance.features_after_mip.append(features) + # Add some information to training_sample # ------------------------------------------------------- training_sample.lower_bound = stats["Lower bound"]