Collect features 3 times (after-load, after-lp, after-mip)

master
Alinson S. Xavier 5 years ago
parent d85a63f869
commit 6afdf2ed55
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -33,6 +33,9 @@ class Instance(ABC, EnforceOverrides):
def __init__(self) -> None: def __init__(self) -> None:
self.training_data: List[TrainingSample] = [] self.training_data: List[TrainingSample] = []
self.features: Features = Features() self.features: Features = Features()
self.features_after_load: List[Features] = []
self.features_after_lp: List[Features] = []
self.features_after_mip: List[Features] = []
@abstractmethod @abstractmethod
def to_model(self) -> Any: def to_model(self) -> Any:

@ -123,6 +123,9 @@ class PickleGzInstance(Instance):
self.instance = obj self.instance = obj
self.features = self.instance.features self.features = self.instance.features
self.training_data = self.instance.training_data self.training_data = self.instance.training_data
self.features_after_load = self.instance.features_after_load
self.features_after_lp = self.instance.features_after_lp
self.features_after_mip = self.instance.features_after_mip
@overrides @overrides
def free(self) -> None: def free(self) -> None:

@ -153,11 +153,12 @@ class LearningSolver:
assert isinstance(self.internal_solver, InternalSolver) assert isinstance(self.internal_solver, InternalSolver)
self.internal_solver.set_instance(instance, model) self.internal_solver.set_instance(instance, model)
# Extract features # Extract features (after-load)
# ------------------------------------------------------- # -------------------------------------------------------
if instance.features.instance is None: logger.info("Extracting features (after-load)...")
logger.info("Extracting features...") features = FeaturesExtractor(self.internal_solver).extract(instance)
FeaturesExtractor(self.internal_solver).extract(instance) instance.features.__dict__ = features.__dict__
instance.features_after_load.append(features)
callback_args = ( callback_args = (
self, self,
@ -186,6 +187,12 @@ class LearningSolver:
for component in self.components.values(): for component in self.components.values():
component.after_solve_lp(*callback_args) component.after_solve_lp(*callback_args)
# Extract features (after-lp)
# -------------------------------------------------------
logger.info("Extracting features (after-lp)...")
features = FeaturesExtractor(self.internal_solver).extract(instance)
instance.features_after_lp.append(features)
# Callback wrappers # Callback wrappers
# ------------------------------------------------------- # -------------------------------------------------------
def iteration_cb_wrapper() -> bool: def iteration_cb_wrapper() -> bool:
@ -242,6 +249,12 @@ class LearningSolver:
) )
stats["Mode"] = self.mode stats["Mode"] = self.mode
# Extract features (after-mip)
# -------------------------------------------------------
logger.info("Extracting features (after-mip)...")
features = FeaturesExtractor(self.internal_solver).extract(instance)
instance.features_after_mip.append(features)
# Add some information to training_sample # Add some information to training_sample
# ------------------------------------------------------- # -------------------------------------------------------
training_sample.lower_bound = stats["Lower bound"] training_sample.lower_bound = stats["Lower bound"]

Loading…
Cancel
Save