Move user_cuts/lazy_enforced to sample.data

This commit is contained in:
2021-07-01 08:37:31 -05:00
parent 80281df8d8
commit 061b1349fe
10 changed files with 56 additions and 66 deletions

View File

@@ -81,13 +81,12 @@ class DynamicConstraintsComponent(Component):
cids[category].append(cid)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.extra is not None
if sample.after_mip.extra[self.attr] is not None:
if cid in sample.after_mip.extra[self.attr]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
enforced_cids = sample.get(self.attr)
if enforced_cids is not None:
if cid in enforced_cids:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
return x, y, cids
@overrides
@@ -133,13 +132,7 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
return
return sample.after_mip.extra[self.attr]
return sample.get(self.attr)
@overrides
def fit_xy(
@@ -161,10 +154,8 @@ class DynamicConstraintsComponent(Component):
instance: Instance,
sample: Sample,
) -> Dict[Hashable, Dict[str, float]]:
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
assert self.attr in sample.after_mip.extra
actual = sample.after_mip.extra[self.attr]
actual = sample.get(self.attr)
assert actual is not None
pred = set(self.sample_predict(instance, sample))
tp: Dict[Hashable, int] = {}
tn: Dict[Hashable, int] = {}

View File

@@ -78,9 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
sample.after_mip.extra["lazy_enforced"] = set(self.lazy_enforced)
sample.put("lazy_enforced", set(self.lazy_enforced))
@overrides
def iteration_cb(

View File

@@ -87,9 +87,7 @@ class UserCutsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
sample.put("user_cuts_enforced", set(self.enforced))
stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback")

View File

@@ -60,9 +60,7 @@ class StaticLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
sample.after_mip.extra["lazy_enforced"] = self.enforced_cids
sample.put("lazy_enforced", self.enforced_cids)
stats["LazyStatic: Restored"] = self.n_restored
stats["LazyStatic: Iterations"] = self.n_iterations
@@ -236,12 +234,9 @@ class StaticLazyConstraintsComponent(Component):
cids[category].append(cname)
# Labels
if (
(sample.after_mip is not None)
and (sample.after_mip.extra is not None)
and ("lazy_enforced" in sample.after_mip.extra)
):
if cname in sample.after_mip.extra["lazy_enforced"]:
lazy_enforced = sample.get("lazy_enforced")
if lazy_enforced is not None:
if cname in lazy_enforced:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]