StaticLazy: Refactor
mirror of https://github.com/ANL-CEEESA/MIPLearn.git
@@ -151,8 +151,8 @@ class Component:
 
     def fit_xy(
         self,
-        x: Dict[str, np.ndarray],
-        y: Dict[str, np.ndarray],
+        x: Dict[Hashable, np.ndarray],
+        y: Dict[Hashable, np.ndarray],
     ) -> None:
         """
         Given two dictionaries x and y, mapping the name of the category to matrices
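
Note: fit_xy consumes training data already grouped by category, one feature
matrix and one label matrix per key. A toy illustration of the expected shapes
(category names and values below are invented, not taken from the commit):

    import numpy as np

    # One entry per category; rows of x[c] and y[c] line up sample-by-sample.
    x = {
        "type-a": np.array([[0.1, 2.0], [0.3, 1.5]]),
        "type-b": np.array([[5.0, 0.2]]),
    }
    # Two-column one-hot labels, matching the [[True, False]] convention
    # used by sample_xy elsewhere in this diff.
    y = {
        "type-a": np.array([[True, False], [False, True]]),
        "type-b": np.array([[False, True]]),
    }
    # component.fit_xy(x, y) then fits one classifier per category.
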
@@ -4,7 +4,7 @@
 
 import logging
 import sys
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, List, Hashable, Any, TYPE_CHECKING, Set
 
 import numpy as np
 from tqdm.auto import tqdm
@@ -12,203 +12,163 @@ from tqdm.auto import tqdm
 from miplearn import Classifier
 from miplearn.classifiers.counting import CountingClassifier
 from miplearn.components.component import Component
-from miplearn.types import TrainingSample, Features
+from miplearn.types import TrainingSample, Features, LearningSolveStats
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from miplearn.solvers.learning import LearningSolver, Instance
+
 
 class LazyConstraint:
-    def __init__(self, cid, obj):
+    def __init__(self, cid: str, obj: Any) -> None:
         self.cid = cid
         self.obj = obj
 
 
 class StaticLazyConstraintsComponent(Component):
     """
     Component that decides which of the constraints tagged as lazy should
     be kept in the formulation, and which should be removed.
     """
 
     def __init__(
         self,
-        classifier=CountingClassifier(),
-        threshold=0.05,
-        use_two_phase_gap=True,
-        large_gap=1e-2,
-        violation_tolerance=-0.5,
-    ):
+        classifier: Classifier = CountingClassifier(),
+        threshold: float = 0.05,
+        violation_tolerance: float = -0.5,
+    ) -> None:
         assert isinstance(classifier, Classifier)
-        self.threshold = threshold
-        self.classifier_prototype = classifier
-        self.classifiers = {}
-        self.pool = []
-        self.original_gap = None
-        self.large_gap = large_gap
-        self.is_gap_large = False
-        self.use_two_phase_gap = use_two_phase_gap
-        self.violation_tolerance = violation_tolerance
+        self.threshold: float = threshold
+        self.classifier_prototype: Classifier = classifier
+        self.classifiers: Dict[Hashable, Classifier] = {}
+        self.pool: Dict[str, LazyConstraint] = {}
+        self.violation_tolerance: float = violation_tolerance
+        self.enforced_cids: Set[str] = set()
+        self.n_restored: int = 0
+        self.n_iterations: int = 0
 
     def before_solve_mip(
         self,
-        solver,
-        instance,
-        model,
-        stats,
-        features,
-        training_data,
-    ):
-        self.pool = []
-        if not solver.use_lazy_cb and self.use_two_phase_gap:
-            logger.info("Increasing gap tolerance to %f", self.large_gap)
-            self.original_gap = solver.gap_tolerance
-            self.is_gap_large = True
-            solver.internal_solver.set_gap_tolerance(self.large_gap)
-
-        instance.found_violated_lazy_constraints = []
-        if instance.has_static_lazy_constraints():
-            self._extract_and_predict_static(solver, instance)
+        solver: "LearningSolver",
+        instance: "Instance",
+        model: Any,
+        stats: LearningSolveStats,
+        features: Features,
+        training_data: TrainingSample,
+    ) -> None:
+        assert solver.internal_solver is not None
+        if features["Instance"]["Lazy constraint count"] == 0:
+            logger.info("Instance does not have static lazy constraints. Skipping.")
+            return
+        logger.info("Predicting required lazy constraints...")
+        self.enforced_cids = set(self.sample_predict(features, training_data))
+        logger.info("Moving lazy constraints to the pool...")
+        self.pool = {}
+        for (cid, cdict) in features["Constraints"].items():
+            if cdict["Lazy"] and cid not in self.enforced_cids:
+                self.pool[cid] = LazyConstraint(
+                    cid=cid,
+                    obj=solver.internal_solver.extract_constraint(cid),
+                )
+        logger.info(
+            f"{len(self.enforced_cids)} lazy constraints kept; "
+            f"{len(self.pool)} moved to the pool"
+        )
+        stats["LazyStatic: Removed"] = len(self.pool)
+        stats["LazyStatic: Kept"] = len(self.enforced_cids)
+        stats["LazyStatic: Restored"] = 0
+        self.n_restored = 0
+        self.n_iterations = 0
+
+    def after_solve_mip(
+        self,
+        solver: "LearningSolver",
+        instance: "Instance",
+        model: Any,
+        stats: LearningSolveStats,
+        features: Features,
+        training_data: TrainingSample,
+    ) -> None:
+        training_data["LazyStatic: Enforced"] = self.enforced_cids
+        stats["LazyStatic: Restored"] = self.n_restored
+        stats["LazyStatic: Iterations"] = self.n_iterations
 
-    def iteration_cb(self, solver, instance, model):
+    def iteration_cb(
+        self,
+        solver: "LearningSolver",
+        instance: "Instance",
+        model: Any,
+    ) -> bool:
         if solver.use_lazy_cb:
             return False
         else:
-            should_repeat = self._check_and_add(instance, solver)
-            if should_repeat:
-                return True
-            else:
-                if self.is_gap_large:
-                    logger.info("Restoring gap tolerance to %f", self.original_gap)
-                    solver.internal_solver.set_gap_tolerance(self.original_gap)
-                    self.is_gap_large = False
-                    return True
-                else:
-                    return False
+            return self._check_and_add(solver)
 
-    def lazy_cb(self, solver, instance, model):
-        self._check_and_add(instance, solver)
+    def lazy_cb(
+        self,
+        solver: "LearningSolver",
+        instance: "Instance",
+        model: Any,
+    ) -> None:
+        self._check_and_add(solver)
 
-    def _check_and_add(self, instance, solver):
-        logger.debug("Finding violated lazy constraints...")
-        constraints_to_add = []
-        for c in self.pool:
+    def _check_and_add(self, solver: "LearningSolver") -> bool:
+        assert solver.internal_solver is not None
+        logger.info("Finding violated lazy constraints...")
+        enforced: List[LazyConstraint] = []
+        for (cid, c) in self.pool.items():
             if not solver.internal_solver.is_constraint_satisfied(
-                c.obj, tol=self.violation_tolerance
+                c.obj,
+                tol=self.violation_tolerance,
             ):
-                constraints_to_add.append(c)
-        for c in constraints_to_add:
-            self.pool.remove(c)
+                enforced.append(c)
+        logger.info(f"{len(enforced)} violations found")
+        for c in enforced:
+            del self.pool[c.cid]
             solver.internal_solver.add_constraint(c.obj)
-            instance.found_violated_lazy_constraints += [c.cid]
-        if len(constraints_to_add) > 0:
-            logger.info(
-                "%8d lazy constraints added %8d in the pool"
-                % (len(constraints_to_add), len(self.pool))
-            )
+            self.enforced_cids.add(c.cid)
+            self.n_restored += 1
+        logger.info(
+            f"{len(enforced)} constraints restored; {len(self.pool)} in the pool"
+        )
+        if len(enforced) > 0:
+            self.n_iterations += 1
+            return True
+        else:
+            return False
 
-    def fit(self, training_instances):
-        training_instances = [
-            t
-            for t in training_instances
-            if hasattr(t, "found_violated_lazy_constraints")
-        ]
-
-        logger.debug("Extracting x and y...")
-        x = self.x(training_instances)
-        y = self.y(training_instances)
-
-        logger.debug("Fitting...")
-        for category in tqdm(
-            x.keys(), desc="Fit (lazy)", disable=not sys.stdout.isatty()
-        ):
-            if category not in self.classifiers:
-                self.classifiers[category] = self.classifier_prototype.clone()
-            self.classifiers[category].fit(x[category], y[category])
-
-    def predict(self, instance):
-        pass
-
-    def evaluate(self, instances):
-        pass
-
-    def _extract_and_predict_static(self, solver, instance):
-        x = {}
-        constraints = {}
-        logger.info("Extracting lazy constraints...")
-        for cid in solver.internal_solver.get_constraint_ids():
-            if instance.is_constraint_lazy(cid):
-                category = instance.get_constraint_category(cid)
-                if category not in x:
-                    x[category] = []
-                    constraints[category] = []
-                x[category] += [instance.get_constraint_features(cid)]
-                c = LazyConstraint(
-                    cid=cid,
-                    obj=solver.internal_solver.extract_constraint(cid),
-                )
-                constraints[category] += [c]
-                self.pool.append(c)
-        logger.info("%8d lazy constraints extracted" % len(self.pool))
-        logger.info("Predicting required lazy constraints...")
-        n_added = 0
-        for (category, x_values) in x.items():
-            if isinstance(x_values[0], np.ndarray):
-                x[category] = np.array(x_values)
-            proba = self.classifiers[category].predict_proba(x[category])
-            for i in range(len(proba)):
-                if proba[i][1] > self.threshold:
-                    n_added += 1
-                    c = constraints[category][i]
-                    self.pool.remove(c)
-                    solver.internal_solver.add_constraint(c.obj)
-                    instance.found_violated_lazy_constraints += [c.cid]
-        logger.info(
-            "%8d lazy constraints added %8d in the pool"
-            % (
-                n_added,
-                len(self.pool),
-            )
-        )
-
-    def _collect_constraints(self, train_instances):
-        constraints = {}
-        for instance in train_instances:
-            for cid in instance.found_violated_lazy_constraints:
-                category = instance.get_constraint_category(cid)
-                if category not in constraints:
-                    constraints[category] = set()
-                constraints[category].add(cid)
-        for (category, cids) in constraints.items():
-            constraints[category] = sorted(list(cids))
-        return constraints
-
-    def x(self, train_instances):
-        result = {}
-        constraints = self._collect_constraints(train_instances)
-        for (category, cids) in constraints.items():
-            result[category] = []
-            for instance in train_instances:
-                for cid in cids:
-                    result[category].append(instance.get_constraint_features(cid))
-        return result
-
-    def y(self, train_instances):
-        result = {}
-        constraints = self._collect_constraints(train_instances)
-        for (category, cids) in constraints.items():
-            result[category] = []
-            for instance in train_instances:
-                for cid in cids:
-                    if cid in instance.found_violated_lazy_constraints:
-                        result[category].append([0, 1])
-                    else:
-                        result[category].append([1, 0])
-        return result
+    def sample_predict(
+        self,
+        features: Features,
+        sample: TrainingSample,
+    ) -> List[str]:
+        x, y = self.sample_xy(features, sample)
+        category_to_cids: Dict[Hashable, List[str]] = {}
+        for (cid, cdict) in features["Constraints"].items():
+            if "Category" not in cdict or cdict["Category"] is None:
+                continue
+            category = cdict["Category"]
+            if category not in category_to_cids:
+                category_to_cids[category] = []
+            category_to_cids[category] += [cid]
+        enforced_cids: List[str] = []
+        for category in x.keys():
+            if category not in self.classifiers:
+                continue
+            clf = self.classifiers[category]
+            proba = clf.predict_proba(np.array(x[category]))
+            pred = list(proba[:, 1] > self.threshold)
+            for (i, is_selected) in enumerate(pred):
+                if is_selected:
+                    enforced_cids += [category_to_cids[category][i]]
+        return enforced_cids
 
     @staticmethod
     def sample_xy(
         features: Features,
         sample: TrainingSample,
-    ) -> Tuple[Dict, Dict]:
+    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
         x: Dict = {}
         y: Dict = {}
         for (cid, cfeatures) in features["Constraints"].items():
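
Note: after this refactoring the pool is a plain dict keyed by constraint id,
and _check_and_add moves any violated member back into the model. A reduced
sketch of that bookkeeping, with a stub standing in for the internal solver's
satisfaction test (all names and values here are illustrative):

    class LazyConstraint:
        def __init__(self, cid, obj):
            self.cid = cid
            self.obj = obj

    # Stub: treat any "slack" below the tolerance as a violation.
    def is_satisfied(obj, tol=-0.5):
        return obj >= tol

    pool = {c.cid: c for c in [LazyConstraint("c1", 0.2), LazyConstraint("c2", -1.0)]}
    enforced = [c for c in pool.values() if not is_satisfied(c.obj)]
    for c in enforced:
        del pool[c.cid]  # restored to the model, removed from the pool
    print([c.cid for c in enforced], list(pool))  # ['c2'] ['c1']
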
@@ -227,3 +187,13 @@ class StaticLazyConstraintsComponent(Component):
                 else:
                     y[category] += [[True, False]]
         return x, y
+
+    def fit_xy(
+        self,
+        x: Dict[Hashable, np.ndarray],
+        y: Dict[Hashable, np.ndarray],
+    ) -> None:
+        for c in y.keys():
+            assert c in x
+            self.classifiers[c] = self.classifier_prototype.clone()
+            self.classifiers[c].fit(x[c], y[c])
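
Note: fit_xy clones the prototype classifier once per category, so no two
categories ever share fitted state. A minimal sketch of the pattern (the
Prototype class below is hypothetical, standing in for MIPLearn's Classifier):

    import copy

    class Prototype:
        def __init__(self):
            self.state = None

        def clone(self):
            # Fresh, unfitted copy; fitted state is never shared.
            return copy.deepcopy(self)

        def fit(self, x, y):
            self.state = (len(x), len(y))

    prototype = Prototype()
    classifiers = {c: prototype.clone() for c in ["cat-a", "cat-b"]}
    for clf in classifiers.values():
        clf.fit([[1.0]], [[0, 1]])
    assert classifiers["cat-a"] is not classifiers["cat-b"]
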
@@ -58,8 +58,8 @@ class ObjectiveValueComponent(Component):
 
     def fit_xy(
         self,
-        x: Dict[str, np.ndarray],
-        y: Dict[str, np.ndarray],
+        x: Dict[Hashable, np.ndarray],
+        y: Dict[Hashable, np.ndarray],
     ) -> None:
         for c in ["Upper bound", "Lower bound"]:
             if c in y:
@@ -84,9 +84,9 @@ class ObjectiveValueComponent(Component):
     def sample_xy(
         features: Features,
         sample: TrainingSample,
-    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
-        x: Dict[str, List[List[float]]] = {}
-        y: Dict[str, List[List[float]]] = {}
+    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+        x: Dict[Hashable, List[List[float]]] = {}
+        y: Dict[Hashable, List[List[float]]] = {}
         f = list(features["Instance"]["User features"])
         if "LP value" in sample and sample["LP value"] is not None:
             f += [sample["LP value"]]
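
Note: the feature vector here is the instance's user features, extended with
the LP relaxation value whenever the sample carries one. A small standalone
sketch of that construction (dictionary contents are illustrative):

    features = {"Instance": {"User features": [67.0, 21.75, 1287.92]}}
    sample = {"LP value": 1110.4}

    f = list(features["Instance"]["User features"])
    if "LP value" in sample and sample["LP value"] is not None:
        f += [sample["LP value"]]
    print(f)  # [67.0, 21.75, 1287.92, 1110.4]
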
@@ -148,7 +148,7 @@ class PrimalSolutionComponent(Component):
     def sample_xy(
         features: Features,
         sample: TrainingSample,
-    ) -> Tuple[Dict, Dict]:
+    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
         x: Dict = {}
         y: Dict = {}
         solution: Optional[Solution] = None
@@ -227,8 +227,8 @@ class PrimalSolutionComponent(Component):
 
     def fit_xy(
         self,
-        x: Dict[str, np.ndarray],
-        y: Dict[str, np.ndarray],
+        x: Dict[Hashable, np.ndarray],
+        y: Dict[Hashable, np.ndarray],
     ) -> None:
         for category in x.keys():
             clf = self.classifier_prototype.clone()
@@ -20,11 +20,12 @@ class FeaturesExtractor:
         self.solver = internal_solver
 
     def extract(self, instance: "Instance") -> Features:
-        return {
-            "Instance": self._extract_instance(instance),
-            "Variables": self._extract_variables(instance),
-            "Constraints": self._extract_constraints(instance),
-        }
+        features: Features = {
+            "Variables": self._extract_variables(instance),
+            "Constraints": self._extract_constraints(instance),
+        }
+        features["Instance"] = self._extract_instance(instance, features)
+        return features
 
     def _extract_variables(self, instance: "Instance") -> Dict:
         variables = self.solver.get_empty_solution()
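
Note: the reordering in extract is deliberate: _extract_instance now receives
the partially built features dict, because the instance-level lazy-constraint
count is derived from the already-extracted constraints. A reduced sketch of
the dependency (dictionary contents are illustrative):

    def extract_constraints():
        return {"c1": {"Lazy": True}, "c2": {"Lazy": False}}

    def extract_instance(features):
        # Depends on features["Constraints"] being filled in first.
        lazy = sum(1 for c in features["Constraints"].values() if c["Lazy"])
        return {"User features": [1.0], "Lazy constraint count": lazy}

    features = {"Constraints": extract_constraints()}
    features["Instance"] = extract_instance(features)
    print(features["Instance"]["Lazy constraint count"])  # 1
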
@@ -92,7 +93,10 @@ class FeaturesExtractor:
         return constraints
 
     @staticmethod
-    def _extract_instance(instance: "Instance") -> InstanceFeatures:
+    def _extract_instance(
+        instance: "Instance",
+        features: Features,
+    ) -> InstanceFeatures:
         user_features = instance.get_instance_features()
         assert isinstance(user_features, list), (
             f"Instance features must be a list. "
@@ -103,4 +107,11 @@ class FeaturesExtractor:
                 f"Instance features must be a list of numbers. "
                 f"Found {type(v).__name__} instead."
             )
-        return {"User features": user_features}
+        lazy_count = 0
+        for (cid, cdict) in features["Constraints"].items():
+            if cdict["Lazy"]:
+                lazy_count += 1
+        return {
+            "User features": user_features,
+            "Lazy constraint count": lazy_count,
+        }
@@ -69,6 +69,10 @@ LearningSolveStats = TypedDict(
         "Upper bound": Optional[float],
         "Wallclock time": float,
         "Warm start value": Optional[float],
+        "LazyStatic: Removed": int,
+        "LazyStatic: Kept": int,
+        "LazyStatic: Restored": int,
+        "LazyStatic: Iterations": int,
     },
     total=False,
 )
@@ -77,6 +81,7 @@ InstanceFeatures = TypedDict(
     "InstanceFeatures",
     {
         "User features": List[float],
+        "Lazy constraint count": int,
     },
     total=False,
 )
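
Note: both TypedDicts are declared with total=False, so the new keys are
optional: a stats dictionary remains valid before the lazy component fills in
its entries. A minimal sketch of the pattern (standalone, not the library's
actual definitions; requires Python 3.8+ for typing.TypedDict):

    from typing import TypedDict

    Stats = TypedDict(
        "Stats",
        {
            "Wallclock time": float,
            "LazyStatic: Restored": int,
        },
        total=False,
    )

    s: Stats = {"Wallclock time": 1.5}  # valid even with keys missing
    s["LazyStatic: Restored"] = 3       # filled in later by a component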