|
|
|
@ -38,20 +38,19 @@ class UserCutsComponent(Component):
|
|
|
|
|
self.n_added_in_callback = 0
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def before_solve_mip_old(
|
|
|
|
|
def before_solve_mip(
|
|
|
|
|
self,
|
|
|
|
|
solver: "LearningSolver",
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
model: Any,
|
|
|
|
|
stats: LearningSolveStats,
|
|
|
|
|
features: Features,
|
|
|
|
|
training_data: TrainingSample,
|
|
|
|
|
sample: Sample,
|
|
|
|
|
) -> None:
|
|
|
|
|
assert solver.internal_solver is not None
|
|
|
|
|
self.enforced.clear()
|
|
|
|
|
self.n_added_in_callback = 0
|
|
|
|
|
logger.info("Predicting violated user cuts...")
|
|
|
|
|
cids = self.dynamic.sample_predict_old(instance, training_data)
|
|
|
|
|
cids = self.dynamic.sample_predict(instance, sample)
|
|
|
|
|
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
|
|
|
|
for cid in cids:
|
|
|
|
|
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
|
|
|
@ -80,34 +79,27 @@ class UserCutsComponent(Component):
|
|
|
|
|
logger.debug(f"Added {len(cids)} violated user cuts")
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def after_solve_mip_old(
|
|
|
|
|
def after_solve_mip(
|
|
|
|
|
self,
|
|
|
|
|
solver: "LearningSolver",
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
model: Any,
|
|
|
|
|
stats: LearningSolveStats,
|
|
|
|
|
features: Features,
|
|
|
|
|
training_data: TrainingSample,
|
|
|
|
|
sample: Sample,
|
|
|
|
|
) -> None:
|
|
|
|
|
training_data.user_cuts_enforced = set(self.enforced)
|
|
|
|
|
assert sample.after_mip is not None
|
|
|
|
|
assert sample.after_mip.extra is not None
|
|
|
|
|
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
|
|
|
|
|
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
|
|
|
|
if self.n_added_in_callback > 0:
|
|
|
|
|
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
|
|
|
|
|
|
|
|
|
# Delegate ML methods to self.dynamic
|
|
|
|
|
# -------------------------------------------------------------------
|
|
|
|
|
@overrides
|
|
|
|
|
def sample_xy_old(
|
|
|
|
|
self,
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
sample: TrainingSample,
|
|
|
|
|
) -> Tuple[Dict, Dict]:
|
|
|
|
|
return self.dynamic.sample_xy_old(instance, sample)
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def sample_xy(
|
|
|
|
|
self,
|
|
|
|
|
instance: Optional[Instance],
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
sample: Sample,
|
|
|
|
|
) -> Tuple[Dict, Dict]:
|
|
|
|
|
return self.dynamic.sample_xy(instance, sample)
|
|
|
|
@ -115,13 +107,13 @@ class UserCutsComponent(Component):
|
|
|
|
|
def sample_predict(
|
|
|
|
|
self,
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
sample: TrainingSample,
|
|
|
|
|
sample: Sample,
|
|
|
|
|
) -> List[Hashable]:
|
|
|
|
|
return self.dynamic.sample_predict(instance, sample)
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def fit_old(self, training_instances: List["Instance"]) -> None:
|
|
|
|
|
self.dynamic.fit_old(training_instances)
|
|
|
|
|
def fit(self, training_instances: List["Instance"]) -> None:
|
|
|
|
|
self.dynamic.fit(training_instances)
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def fit_xy(
|
|
|
|
@ -132,9 +124,9 @@ class UserCutsComponent(Component):
|
|
|
|
|
self.dynamic.fit_xy(x, y)
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
|
def sample_evaluate_old(
|
|
|
|
|
def sample_evaluate(
|
|
|
|
|
self,
|
|
|
|
|
instance: "Instance",
|
|
|
|
|
sample: TrainingSample,
|
|
|
|
|
sample: Sample,
|
|
|
|
|
) -> Dict[Hashable, Dict[str, float]]:
|
|
|
|
|
return self.dynamic.sample_evaluate_old(instance, sample)
|
|
|
|
|
return self.dynamic.sample_evaluate(instance, sample)
|
|
|
|
|