mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Collect features 3 times (after-load, after-lp, after-mip)
This commit is contained in:
@@ -33,6 +33,9 @@ class Instance(ABC, EnforceOverrides):
|
||||
def __init__(self) -> None:
|
||||
self.training_data: List[TrainingSample] = []
|
||||
self.features: Features = Features()
|
||||
self.features_after_load: List[Features] = []
|
||||
self.features_after_lp: List[Features] = []
|
||||
self.features_after_mip: List[Features] = []
|
||||
|
||||
@abstractmethod
|
||||
def to_model(self) -> Any:
|
||||
|
||||
@@ -123,6 +123,9 @@ class PickleGzInstance(Instance):
|
||||
self.instance = obj
|
||||
self.features = self.instance.features
|
||||
self.training_data = self.instance.training_data
|
||||
self.features_after_load = self.instance.features_after_load
|
||||
self.features_after_lp = self.instance.features_after_lp
|
||||
self.features_after_mip = self.instance.features_after_mip
|
||||
|
||||
@overrides
|
||||
def free(self) -> None:
|
||||
|
||||
@@ -153,11 +153,12 @@ class LearningSolver:
|
||||
assert isinstance(self.internal_solver, InternalSolver)
|
||||
self.internal_solver.set_instance(instance, model)
|
||||
|
||||
# Extract features
|
||||
# Extract features (after-load)
|
||||
# -------------------------------------------------------
|
||||
if instance.features.instance is None:
|
||||
logger.info("Extracting features...")
|
||||
FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
logger.info("Extracting features (after-load)...")
|
||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
instance.features.__dict__ = features.__dict__
|
||||
instance.features_after_load.append(features)
|
||||
|
||||
callback_args = (
|
||||
self,
|
||||
@@ -186,6 +187,12 @@ class LearningSolver:
|
||||
for component in self.components.values():
|
||||
component.after_solve_lp(*callback_args)
|
||||
|
||||
# Extract features (after-lp)
|
||||
# -------------------------------------------------------
|
||||
logger.info("Extracting features (after-lp)...")
|
||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
instance.features_after_lp.append(features)
|
||||
|
||||
# Callback wrappers
|
||||
# -------------------------------------------------------
|
||||
def iteration_cb_wrapper() -> bool:
|
||||
@@ -242,6 +249,12 @@ class LearningSolver:
|
||||
)
|
||||
stats["Mode"] = self.mode
|
||||
|
||||
# Extract features (after-mip)
|
||||
# -------------------------------------------------------
|
||||
logger.info("Extracting features (after-mip)...")
|
||||
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||
instance.features_after_mip.append(features)
|
||||
|
||||
# Add some information to training_sample
|
||||
# -------------------------------------------------------
|
||||
training_sample.lower_bound = stats["Lower bound"]
|
||||
|
||||
Reference in New Issue
Block a user