mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Update UserCutsComponent
This commit is contained in:
@@ -41,7 +41,7 @@ class DynamicLazyConstraintsComponent(Component):
|
|||||||
self.classifiers = self.dynamic.classifiers
|
self.classifiers = self.dynamic.classifiers
|
||||||
self.thresholds = self.dynamic.thresholds
|
self.thresholds = self.dynamic.thresholds
|
||||||
self.known_cids = self.dynamic.known_cids
|
self.known_cids = self.dynamic.known_cids
|
||||||
self.lazy_enforced: Set[str] = set()
|
self.lazy_enforced: Set[Hashable] = set()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def enforce(
|
def enforce(
|
||||||
|
|||||||
@@ -38,20 +38,19 @@ class UserCutsComponent(Component):
|
|||||||
self.n_added_in_callback = 0
|
self.n_added_in_callback = 0
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def before_solve_mip_old(
|
def before_solve_mip(
|
||||||
self,
|
self,
|
||||||
solver: "LearningSolver",
|
solver: "LearningSolver",
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
model: Any,
|
model: Any,
|
||||||
stats: LearningSolveStats,
|
stats: LearningSolveStats,
|
||||||
features: Features,
|
sample: Sample,
|
||||||
training_data: TrainingSample,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
assert solver.internal_solver is not None
|
assert solver.internal_solver is not None
|
||||||
self.enforced.clear()
|
self.enforced.clear()
|
||||||
self.n_added_in_callback = 0
|
self.n_added_in_callback = 0
|
||||||
logger.info("Predicting violated user cuts...")
|
logger.info("Predicting violated user cuts...")
|
||||||
cids = self.dynamic.sample_predict_old(instance, training_data)
|
cids = self.dynamic.sample_predict(instance, sample)
|
||||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
||||||
for cid in cids:
|
for cid in cids:
|
||||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||||
@@ -80,34 +79,27 @@ class UserCutsComponent(Component):
|
|||||||
logger.debug(f"Added {len(cids)} violated user cuts")
|
logger.debug(f"Added {len(cids)} violated user cuts")
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def after_solve_mip_old(
|
def after_solve_mip(
|
||||||
self,
|
self,
|
||||||
solver: "LearningSolver",
|
solver: "LearningSolver",
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
model: Any,
|
model: Any,
|
||||||
stats: LearningSolveStats,
|
stats: LearningSolveStats,
|
||||||
features: Features,
|
sample: Sample,
|
||||||
training_data: TrainingSample,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
training_data.user_cuts_enforced = set(self.enforced)
|
assert sample.after_mip is not None
|
||||||
|
assert sample.after_mip.extra is not None
|
||||||
|
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
|
||||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||||
if self.n_added_in_callback > 0:
|
if self.n_added_in_callback > 0:
|
||||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||||
|
|
||||||
# Delegate ML methods to self.dynamic
|
# Delegate ML methods to self.dynamic
|
||||||
# -------------------------------------------------------------------
|
# -------------------------------------------------------------------
|
||||||
@overrides
|
|
||||||
def sample_xy_old(
|
|
||||||
self,
|
|
||||||
instance: "Instance",
|
|
||||||
sample: TrainingSample,
|
|
||||||
) -> Tuple[Dict, Dict]:
|
|
||||||
return self.dynamic.sample_xy_old(instance, sample)
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def sample_xy(
|
def sample_xy(
|
||||||
self,
|
self,
|
||||||
instance: Optional[Instance],
|
instance: "Instance",
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
return self.dynamic.sample_xy(instance, sample)
|
return self.dynamic.sample_xy(instance, sample)
|
||||||
@@ -115,13 +107,13 @@ class UserCutsComponent(Component):
|
|||||||
def sample_predict(
|
def sample_predict(
|
||||||
self,
|
self,
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
sample: TrainingSample,
|
sample: Sample,
|
||||||
) -> List[Hashable]:
|
) -> List[Hashable]:
|
||||||
return self.dynamic.sample_predict(instance, sample)
|
return self.dynamic.sample_predict(instance, sample)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_old(self, training_instances: List["Instance"]) -> None:
|
def fit(self, training_instances: List["Instance"]) -> None:
|
||||||
self.dynamic.fit_old(training_instances)
|
self.dynamic.fit(training_instances)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
@@ -132,9 +124,9 @@ class UserCutsComponent(Component):
|
|||||||
self.dynamic.fit_xy(x, y)
|
self.dynamic.fit_xy(x, y)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def sample_evaluate_old(
|
def sample_evaluate(
|
||||||
self,
|
self,
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
sample: TrainingSample,
|
sample: Sample,
|
||||||
) -> Dict[Hashable, Dict[str, float]]:
|
) -> Dict[Hashable, Dict[str, float]]:
|
||||||
return self.dynamic.sample_evaluate_old(instance, sample)
|
return self.dynamic.sample_evaluate(instance, sample)
|
||||||
|
|||||||
@@ -80,9 +80,10 @@ def test_usage(
|
|||||||
solver: LearningSolver,
|
solver: LearningSolver,
|
||||||
) -> None:
|
) -> None:
|
||||||
stats_before = solver.solve(stab_instance)
|
stats_before = solver.solve(stab_instance)
|
||||||
sample = stab_instance.training_data[0]
|
sample = stab_instance.samples[0]
|
||||||
assert sample.user_cuts_enforced is not None
|
assert sample.after_mip is not None
|
||||||
assert len(sample.user_cuts_enforced) > 0
|
assert sample.after_mip.extra is not None
|
||||||
|
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
||||||
print(stats_before)
|
print(stats_before)
|
||||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||||
assert stats_before["UserCuts: Added in callback"] > 0
|
assert stats_before["UserCuts: Added in callback"] > 0
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ def test_subtour() -> None:
|
|||||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||||
solver = LearningSolver()
|
solver = LearningSolver()
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
|
assert instance.samples[0].after_mip is not None
|
||||||
|
assert instance.samples[0].after_mip.extra is not None
|
||||||
lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"]
|
lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"]
|
||||||
assert lazy_enforced is not None
|
assert lazy_enforced is not None
|
||||||
assert len(lazy_enforced) > 0
|
assert len(lazy_enforced) > 0
|
||||||
|
|||||||
Reference in New Issue
Block a user