Refactor StaticLazy

2021-04-12 10:05:17 -05:00
parent e6672a45a0
commit cb62345acf
2 changed files with 91 additions and 169 deletions


@@ -52,26 +52,35 @@ class StaticLazyConstraintsComponent(Component):
         self.n_iterations: int = 0
 
     @overrides
-    def before_solve_mip_old(
+    def after_solve_mip(
         self,
         solver: "LearningSolver",
         instance: "Instance",
         model: Any,
         stats: LearningSolveStats,
-        features: Features,
-        training_data: TrainingSample,
+        sample: Sample,
     ) -> None:
+        sample.after_mip.extra["lazy_enforced"] = self.enforced_cids
+        stats["LazyStatic: Restored"] = self.n_restored
+        stats["LazyStatic: Iterations"] = self.n_iterations
+
+    @overrides
+    def before_solve_mip(
+        self,
+        solver: "LearningSolver",
+        instance: "Instance",
+        model: Any,
+        stats: LearningSolveStats,
+        sample: Sample,
+    ) -> None:
         assert solver.internal_solver is not None
-        assert features.instance is not None
-        assert features.constraints is not None
         logger.info("Predicting violated (static) lazy constraints...")
-        if features.instance.lazy_constraint_count == 0:
+        if sample.after_load.instance.lazy_constraint_count == 0:
             logger.info("Instance does not have static lazy constraints. Skipping.")
-        self.enforced_cids = set(self.sample_predict(instance, training_data))
+        self.enforced_cids = set(self.sample_predict(sample))
         logger.info("Moving lazy constraints to the pool...")
         self.pool = {}
-        for (cid, cdict) in features.constraints.items():
+        for (cid, cdict) in sample.after_load.constraints.items():
             if cdict.lazy and cid not in self.enforced_cids:
                 self.pool[cid] = cdict
                 solver.internal_solver.remove_constraint(cid)
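The hunk above replaces the old before_solve_mip_old hook, which took separate Features and TrainingSample arguments, with after_solve_mip plus a new before_solve_mip that read everything from a single Sample (its after_load and after_mip snapshots). A minimal sketch of that flow, using SimpleNamespace stand-ins rather than the real MIPLearn classes; the constraint ids, feature fields, and the hard-coded enforced set are illustrative assumptions:

from types import SimpleNamespace

# Stand-in for the Sample object used by the refactored hooks: one snapshot taken
# after the instance is loaded, another after the MIP solve.
sample = SimpleNamespace(
    after_load=SimpleNamespace(
        instance=SimpleNamespace(lazy_constraint_count=2),
        constraints={
            "c1": SimpleNamespace(lazy=True),
            "c2": SimpleNamespace(lazy=True),
        },
    ),
    after_mip=SimpleNamespace(extra={}),
)

# before_solve_mip: keep the constraints predicted to be necessary, move the rest
# to the pool (in the real component this also removes them from the solver).
enforced_cids = {"c1"}  # would come from sample_predict(sample)
pool = {
    cid: c
    for (cid, c) in sample.after_load.constraints.items()
    if c.lazy and cid not in enforced_cids
}

# after_solve_mip: record what ended up enforced, so the same sample can later be
# turned into training labels by sample_xy.
sample.after_mip.extra["lazy_enforced"] = enforced_cids

print(sorted(pool))            # ['c2']
print(sample.after_mip.extra)  # {'lazy_enforced': {'c1'}}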
@@ -86,18 +95,17 @@ class StaticLazyConstraintsComponent(Component):
         self.n_iterations = 0
 
     @overrides
-    def after_solve_mip_old(
+    def fit_xy(
         self,
-        solver: "LearningSolver",
-        instance: "Instance",
-        model: Any,
-        stats: LearningSolveStats,
-        features: Features,
-        training_data: TrainingSample,
+        x: Dict[Hashable, np.ndarray],
+        y: Dict[Hashable, np.ndarray],
     ) -> None:
-        training_data.lazy_enforced = self.enforced_cids
-        stats["LazyStatic: Restored"] = self.n_restored
-        stats["LazyStatic: Iterations"] = self.n_iterations
+        for c in y.keys():
+            assert c in x
+            self.classifiers[c] = self.classifier_prototype.clone()
+            self.thresholds[c] = self.threshold_prototype.clone()
+            self.classifiers[c].fit(x[c], y[c])
+            self.thresholds[c].fit(self.classifiers[c], x[c], y[c])
 
     @overrides
     def iteration_cb(
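The new fit_xy receives the per-category x/y dictionaries directly and trains one classifier and one threshold per category by cloning prototypes. A rough sketch of the same clone-per-category pattern with plain scikit-learn, since the component's Classifier and Threshold wrappers are not part of this diff; the category name, feature values, and the choice of LogisticRegression are assumptions, and the dynamic-threshold fit step is omitted:

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression

classifier_prototype = LogisticRegression()
classifiers = {}

# Toy training data, one entry per constraint category, following the diff's
# label convention: [True, False] = not enforced, [False, True] = enforced.
x = {"knapsack": np.array([[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]])}
y = {"knapsack": np.array([[True, False], [False, True], [False, True]])}

for c in y.keys():
    assert c in x
    classifiers[c] = clone(classifier_prototype)  # fresh model per category
    classifiers[c].fit(x[c], y[c][:, 1])          # column 1 = "was enforced"

print(classifiers["knapsack"].predict_proba(x["knapsack"])[:, 1])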
@@ -120,6 +128,30 @@ class StaticLazyConstraintsComponent(Component):
     ) -> None:
         self._check_and_add(solver)
 
+    def sample_predict(self, sample: Sample) -> List[Hashable]:
+        x, y, cids = self._sample_xy_with_cids(sample)
+        enforced_cids: List[Hashable] = []
+        for category in x.keys():
+            if category not in self.classifiers:
+                continue
+            npx = np.array(x[category])
+            proba = self.classifiers[category].predict_proba(npx)
+            thr = self.thresholds[category].predict(npx)
+            pred = list(proba[:, 1] > thr[1])
+            for (i, is_selected) in enumerate(pred):
+                if is_selected:
+                    enforced_cids += [cids[category][i]]
+        return enforced_cids
+
+    @overrides
+    def sample_xy(
+        self,
+        _: Optional[Instance],
+        sample: Sample,
+    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+        x, y, _ = self._sample_xy_with_cids(sample)
+        return x, y
+
     def _check_and_add(self, solver: "LearningSolver") -> bool:
         assert solver.internal_solver is not None
         logger.info("Finding violated lazy constraints...")
@@ -145,69 +177,16 @@ class StaticLazyConstraintsComponent(Component):
         else:
             return False
 
-    def sample_predict(
-        self,
-        instance: "Instance",
-        sample: TrainingSample,
-    ) -> List[Hashable]:
-        assert instance.features.constraints is not None
-        x, y = self.sample_xy_old(instance, sample)
-        category_to_cids: Dict[Hashable, List[Hashable]] = {}
-        for (cid, cfeatures) in instance.features.constraints.items():
-            if cfeatures.category is None:
-                continue
-            category = cfeatures.category
-            if category not in category_to_cids:
-                category_to_cids[category] = []
-            category_to_cids[category] += [cid]
-        enforced_cids: List[Hashable] = []
-        for category in x.keys():
-            if category not in self.classifiers:
-                continue
-            npx = np.array(x[category])
-            proba = self.classifiers[category].predict_proba(npx)
-            thr = self.thresholds[category].predict(npx)
-            pred = list(proba[:, 1] > thr[1])
-            for (i, is_selected) in enumerate(pred):
-                if is_selected:
-                    enforced_cids += [category_to_cids[category][i]]
-        return enforced_cids
-
-    @overrides
-    def sample_xy_old(
-        self,
-        instance: "Instance",
-        sample: TrainingSample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
-        assert instance.features.constraints is not None
-        x: Dict = {}
-        y: Dict = {}
-        for (cid, cfeatures) in instance.features.constraints.items():
-            if not cfeatures.lazy:
-                continue
-            category = cfeatures.category
-            if category is None:
-                continue
-            if category not in x:
-                x[category] = []
-                y[category] = []
-            x[category] += [cfeatures.user_features]
-            if sample.lazy_enforced is not None:
-                if cid in sample.lazy_enforced:
-                    y[category] += [[False, True]]
-                else:
-                    y[category] += [[True, False]]
-        return x, y
-
-    @overrides
-    def sample_xy(
-        self,
-        _: Optional[Instance],
-        sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
-        x: Dict = {}
-        y: Dict = {}
+    def _sample_xy_with_cids(
+        self, sample: Sample
+    ) -> Tuple[
+        Dict[Hashable, List[List[float]]],
+        Dict[Hashable, List[List[float]]],
+        Dict[Hashable, List[str]],
+    ]:
+        x: Dict[Hashable, List[List[float]]] = {}
+        y: Dict[Hashable, List[List[float]]] = {}
+        cids: Dict[Hashable, List[str]] = {}
         assert sample.after_load is not None
         assert sample.after_load.constraints is not None
         for (cid, constr) in sample.after_load.constraints.items():
@@ -220,6 +199,7 @@ class StaticLazyConstraintsComponent(Component):
             if category not in x:
                 x[category] = []
                 y[category] = []
+                cids[category] = []
 
             # Features
             sf = sample.after_load
@@ -231,25 +211,16 @@ class StaticLazyConstraintsComponent(Component):
             assert sf.constraints[cid] is not None
             features.extend(sf.constraints[cid].to_list())
             x[category].append(features)
+            cids[category].append(cid)
 
             # Labels
-            if sample.after_mip is not None:
-                assert sample.after_mip.extra is not None
+            if (
+                (sample.after_mip is not None)
+                and (sample.after_mip.extra is not None)
+                and ("lazy_enforced" in sample.after_mip.extra)
+            ):
                 if cid in sample.after_mip.extra["lazy_enforced"]:
                     y[category] += [[False, True]]
                 else:
                     y[category] += [[True, False]]
-        return x, y
-
-    @overrides
-    def fit_xy(
-        self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
-    ) -> None:
-        for c in y.keys():
-            assert c in x
-            self.classifiers[c] = self.classifier_prototype.clone()
-            self.thresholds[c] = self.threshold_prototype.clone()
-            self.classifiers[c].fit(x[c], y[c])
-            self.thresholds[c].fit(self.classifiers[c], x[c], y[c])
+        return x, y, cids
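The remaining hunks fold the old sample_xy / sample_xy_old pair into a single _sample_xy_with_cids helper that additionally returns, per category, the constraint ids aligned row-by-row with the feature matrix. A hypothetical example of the triple it produces; category name, features, and ids are made up:

# Hypothetical _sample_xy_with_cids output for one category with two lazy constraints.
x = {"cut": [[1.0, 3.0], [2.0, 0.5]]}
y = {"cut": [[False, True], [True, False]]}  # "c7" was enforced, "c9" was not
cids = {"cut": ["c7", "c9"]}

# sample_xy keeps only (x, y) for training, while sample_predict uses cids to map
# each prediction back to a constraint id: row i of x["cut"] belongs to cids["cut"][i].
for (i, cid) in enumerate(cids["cut"]):
    print(cid, x["cut"][i], y["cut"][i])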