mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Collect features 3 times (after-load, after-lp, after-mip)
This commit is contained in:
@@ -33,6 +33,9 @@ class Instance(ABC, EnforceOverrides):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.training_data: List[TrainingSample] = []
|
self.training_data: List[TrainingSample] = []
|
||||||
self.features: Features = Features()
|
self.features: Features = Features()
|
||||||
|
self.features_after_load: List[Features] = []
|
||||||
|
self.features_after_lp: List[Features] = []
|
||||||
|
self.features_after_mip: List[Features] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
|
|||||||
@@ -123,6 +123,9 @@ class PickleGzInstance(Instance):
|
|||||||
self.instance = obj
|
self.instance = obj
|
||||||
self.features = self.instance.features
|
self.features = self.instance.features
|
||||||
self.training_data = self.instance.training_data
|
self.training_data = self.instance.training_data
|
||||||
|
self.features_after_load = self.instance.features_after_load
|
||||||
|
self.features_after_lp = self.instance.features_after_lp
|
||||||
|
self.features_after_mip = self.instance.features_after_mip
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def free(self) -> None:
|
def free(self) -> None:
|
||||||
|
|||||||
@@ -153,11 +153,12 @@ class LearningSolver:
|
|||||||
assert isinstance(self.internal_solver, InternalSolver)
|
assert isinstance(self.internal_solver, InternalSolver)
|
||||||
self.internal_solver.set_instance(instance, model)
|
self.internal_solver.set_instance(instance, model)
|
||||||
|
|
||||||
# Extract features
|
# Extract features (after-load)
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
if instance.features.instance is None:
|
logger.info("Extracting features (after-load)...")
|
||||||
logger.info("Extracting features...")
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
FeaturesExtractor(self.internal_solver).extract(instance)
|
instance.features.__dict__ = features.__dict__
|
||||||
|
instance.features_after_load.append(features)
|
||||||
|
|
||||||
callback_args = (
|
callback_args = (
|
||||||
self,
|
self,
|
||||||
@@ -186,6 +187,12 @@ class LearningSolver:
|
|||||||
for component in self.components.values():
|
for component in self.components.values():
|
||||||
component.after_solve_lp(*callback_args)
|
component.after_solve_lp(*callback_args)
|
||||||
|
|
||||||
|
# Extract features (after-lp)
|
||||||
|
# -------------------------------------------------------
|
||||||
|
logger.info("Extracting features (after-lp)...")
|
||||||
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
|
instance.features_after_lp.append(features)
|
||||||
|
|
||||||
# Callback wrappers
|
# Callback wrappers
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
def iteration_cb_wrapper() -> bool:
|
def iteration_cb_wrapper() -> bool:
|
||||||
@@ -242,6 +249,12 @@ class LearningSolver:
|
|||||||
)
|
)
|
||||||
stats["Mode"] = self.mode
|
stats["Mode"] = self.mode
|
||||||
|
|
||||||
|
# Extract features (after-mip)
|
||||||
|
# -------------------------------------------------------
|
||||||
|
logger.info("Extracting features (after-mip)...")
|
||||||
|
features = FeaturesExtractor(self.internal_solver).extract(instance)
|
||||||
|
instance.features_after_mip.append(features)
|
||||||
|
|
||||||
# Add some information to training_sample
|
# Add some information to training_sample
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
training_sample.lower_bound = stats["Lower bound"]
|
training_sample.lower_bound = stats["Lower bound"]
|
||||||
|
|||||||
Reference in New Issue
Block a user