Update UserCutsComponent

This commit is contained in:
2021-04-13 09:05:34 -05:00
parent a4433916e5
commit c26b852c67
4 changed files with 22 additions and 27 deletions

View File

@@ -41,7 +41,7 @@ class DynamicLazyConstraintsComponent(Component):
self.classifiers = self.dynamic.classifiers
self.thresholds = self.dynamic.thresholds
self.known_cids = self.dynamic.known_cids
self.lazy_enforced: Set[str] = set()
self.lazy_enforced: Set[Hashable] = set()
@staticmethod
def enforce(

View File

@@ -38,20 +38,19 @@ class UserCutsComponent(Component):
self.n_added_in_callback = 0
@overrides
def before_solve_mip_old(
def before_solve_mip(
self,
solver: "LearningSolver",
instance: "Instance",
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
sample: Sample,
) -> None:
assert solver.internal_solver is not None
self.enforced.clear()
self.n_added_in_callback = 0
logger.info("Predicting violated user cuts...")
cids = self.dynamic.sample_predict_old(instance, training_data)
cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
for cid in cids:
instance.enforce_user_cut(solver.internal_solver, model, cid)
@@ -80,34 +79,27 @@ class UserCutsComponent(Component):
logger.debug(f"Added {len(cids)} violated user cuts")
@overrides
def after_solve_mip_old(
def after_solve_mip(
self,
solver: "LearningSolver",
instance: "Instance",
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
sample: Sample,
) -> None:
training_data.user_cuts_enforced = set(self.enforced)
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
# Delegate ML methods to self.dynamic
# -------------------------------------------------------------------
@overrides
def sample_xy_old(
self,
instance: "Instance",
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy_old(instance, sample)
@overrides
def sample_xy(
self,
instance: Optional[Instance],
instance: "Instance",
sample: Sample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
@@ -115,13 +107,13 @@ class UserCutsComponent(Component):
def sample_predict(
self,
instance: "Instance",
sample: TrainingSample,
sample: Sample,
) -> List[Hashable]:
return self.dynamic.sample_predict(instance, sample)
@overrides
def fit_old(self, training_instances: List["Instance"]) -> None:
self.dynamic.fit_old(training_instances)
def fit(self, training_instances: List["Instance"]) -> None:
self.dynamic.fit(training_instances)
@overrides
def fit_xy(
@@ -132,9 +124,9 @@ class UserCutsComponent(Component):
self.dynamic.fit_xy(x, y)
@overrides
def sample_evaluate_old(
def sample_evaluate(
self,
instance: "Instance",
sample: TrainingSample,
sample: Sample,
) -> Dict[Hashable, Dict[str, float]]:
return self.dynamic.sample_evaluate_old(instance, sample)
return self.dynamic.sample_evaluate(instance, sample)