Update UserCutsComponent

master
Alinson S. Xavier 5 years ago
parent a4433916e5
commit c26b852c67
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -41,7 +41,7 @@ class DynamicLazyConstraintsComponent(Component):
self.classifiers = self.dynamic.classifiers self.classifiers = self.dynamic.classifiers
self.thresholds = self.dynamic.thresholds self.thresholds = self.dynamic.thresholds
self.known_cids = self.dynamic.known_cids self.known_cids = self.dynamic.known_cids
self.lazy_enforced: Set[str] = set() self.lazy_enforced: Set[Hashable] = set()
@staticmethod @staticmethod
def enforce( def enforce(

@ -38,20 +38,19 @@ class UserCutsComponent(Component):
self.n_added_in_callback = 0 self.n_added_in_callback = 0
@overrides @overrides
def before_solve_mip_old( def before_solve_mip(
self, self,
solver: "LearningSolver", solver: "LearningSolver",
instance: "Instance", instance: "Instance",
model: Any, model: Any,
stats: LearningSolveStats, stats: LearningSolveStats,
features: Features, sample: Sample,
training_data: TrainingSample,
) -> None: ) -> None:
assert solver.internal_solver is not None assert solver.internal_solver is not None
self.enforced.clear() self.enforced.clear()
self.n_added_in_callback = 0 self.n_added_in_callback = 0
logger.info("Predicting violated user cuts...") logger.info("Predicting violated user cuts...")
cids = self.dynamic.sample_predict_old(instance, training_data) cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids)) logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
for cid in cids: for cid in cids:
instance.enforce_user_cut(solver.internal_solver, model, cid) instance.enforce_user_cut(solver.internal_solver, model, cid)
@ -80,34 +79,27 @@ class UserCutsComponent(Component):
logger.debug(f"Added {len(cids)} violated user cuts") logger.debug(f"Added {len(cids)} violated user cuts")
@overrides @overrides
def after_solve_mip_old( def after_solve_mip(
self, self,
solver: "LearningSolver", solver: "LearningSolver",
instance: "Instance", instance: "Instance",
model: Any, model: Any,
stats: LearningSolveStats, stats: LearningSolveStats,
features: Features, sample: Sample,
training_data: TrainingSample,
) -> None: ) -> None:
training_data.user_cuts_enforced = set(self.enforced) assert sample.after_mip is not None
assert sample.after_mip.extra is not None
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
stats["UserCuts: Added in callback"] = self.n_added_in_callback stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0: if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback") logger.info(f"{self.n_added_in_callback} user cuts added in callback")
# Delegate ML methods to self.dynamic # Delegate ML methods to self.dynamic
# ------------------------------------------------------------------- # -------------------------------------------------------------------
@overrides
def sample_xy_old(
self,
instance: "Instance",
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy_old(instance, sample)
@overrides @overrides
def sample_xy( def sample_xy(
self, self,
instance: Optional[Instance], instance: "Instance",
sample: Sample, sample: Sample,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample) return self.dynamic.sample_xy(instance, sample)
@ -115,13 +107,13 @@ class UserCutsComponent(Component):
def sample_predict( def sample_predict(
self, self,
instance: "Instance", instance: "Instance",
sample: TrainingSample, sample: Sample,
) -> List[Hashable]: ) -> List[Hashable]:
return self.dynamic.sample_predict(instance, sample) return self.dynamic.sample_predict(instance, sample)
@overrides @overrides
def fit_old(self, training_instances: List["Instance"]) -> None: def fit(self, training_instances: List["Instance"]) -> None:
self.dynamic.fit_old(training_instances) self.dynamic.fit(training_instances)
@overrides @overrides
def fit_xy( def fit_xy(
@ -132,9 +124,9 @@ class UserCutsComponent(Component):
self.dynamic.fit_xy(x, y) self.dynamic.fit_xy(x, y)
@overrides @overrides
def sample_evaluate_old( def sample_evaluate(
self, self,
instance: "Instance", instance: "Instance",
sample: TrainingSample, sample: Sample,
) -> Dict[Hashable, Dict[str, float]]: ) -> Dict[Hashable, Dict[str, float]]:
return self.dynamic.sample_evaluate_old(instance, sample) return self.dynamic.sample_evaluate(instance, sample)

@ -80,9 +80,10 @@ def test_usage(
solver: LearningSolver, solver: LearningSolver,
) -> None: ) -> None:
stats_before = solver.solve(stab_instance) stats_before = solver.solve(stab_instance)
sample = stab_instance.training_data[0] sample = stab_instance.samples[0]
assert sample.user_cuts_enforced is not None assert sample.after_mip is not None
assert len(sample.user_cuts_enforced) > 0 assert sample.after_mip.extra is not None
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
print(stats_before) print(stats_before)
assert stats_before["UserCuts: Added ahead-of-time"] == 0 assert stats_before["UserCuts: Added ahead-of-time"] == 0
assert stats_before["UserCuts: Added in callback"] > 0 assert stats_before["UserCuts: Added in callback"] > 0

@ -67,6 +67,8 @@ def test_subtour() -> None:
instance = TravelingSalesmanInstance(n_cities, distances) instance = TravelingSalesmanInstance(n_cities, distances)
solver = LearningSolver() solver = LearningSolver()
solver.solve(instance) solver.solve(instance)
assert instance.samples[0].after_mip is not None
assert instance.samples[0].after_mip.extra is not None
lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"] lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"]
assert lazy_enforced is not None assert lazy_enforced is not None
assert len(lazy_enforced) > 0 assert len(lazy_enforced) > 0

Loading…
Cancel
Save