mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Lazy: Rename fields
This commit is contained in:
@@ -58,11 +58,11 @@ class BasicCollector:
|
|||||||
|
|
||||||
# Add lazy constraints to model
|
# Add lazy constraints to model
|
||||||
if (
|
if (
|
||||||
hasattr(model, "fix_violations")
|
hasattr(model, "lazy_enforce")
|
||||||
and model.fix_violations is not None
|
and model.lazy_enforce is not None
|
||||||
):
|
):
|
||||||
model.fix_violations(model, model.violations_, "aot")
|
model.lazy_enforce(model, model.lazy_constrs_, "aot")
|
||||||
h5.put_scalar("mip_constr_violations", repr(model.violations_))
|
h5.put_scalar("mip_lazy", repr(model.lazy_constrs_))
|
||||||
|
|
||||||
# Save MPS file
|
# Save MPS file
|
||||||
model.write(mps_filename)
|
model.write(mps_filename)
|
||||||
|
|||||||
@@ -24,30 +24,30 @@ class MemorizingLazyConstrComponent:
|
|||||||
def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None:
|
def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None:
|
||||||
self.clf = clf
|
self.clf = clf
|
||||||
self.extractor = extractor
|
self.extractor = extractor
|
||||||
self.violations_: List[Hashable] = []
|
self.constrs_: List[Hashable] = []
|
||||||
self.n_features_: int = 0
|
self.n_features_: int = 0
|
||||||
self.n_targets_: int = 0
|
self.n_targets_: int = 0
|
||||||
|
|
||||||
def fit(self, train_h5: List[str]) -> None:
|
def fit(self, train_h5: List[str]) -> None:
|
||||||
logger.info("Reading training data...")
|
logger.info("Reading training data...")
|
||||||
n_samples = len(train_h5)
|
n_samples = len(train_h5)
|
||||||
x, y, violations, n_features = [], [], [], None
|
x, y, constrs, n_features = [], [], [], None
|
||||||
violation_to_idx: Dict[Hashable, int] = {}
|
constr_to_idx: Dict[Hashable, int] = {}
|
||||||
for h5_filename in train_h5:
|
for h5_filename in train_h5:
|
||||||
with H5File(h5_filename, "r") as h5:
|
with H5File(h5_filename, "r") as h5:
|
||||||
|
|
||||||
# Store lazy constraints
|
# Store lazy constraints
|
||||||
sample_violations_str = h5.get_scalar("mip_constr_violations")
|
sample_constrs_str = h5.get_scalar("mip_lazy")
|
||||||
assert sample_violations_str is not None
|
assert sample_constrs_str is not None
|
||||||
assert isinstance(sample_violations_str, str)
|
assert isinstance(sample_constrs_str, str)
|
||||||
sample_violations = eval(sample_violations_str)
|
sample_constrs = eval(sample_constrs_str)
|
||||||
assert isinstance(sample_violations, list)
|
assert isinstance(sample_constrs, list)
|
||||||
y_sample = []
|
y_sample = []
|
||||||
for v in sample_violations:
|
for c in sample_constrs:
|
||||||
if v not in violation_to_idx:
|
if c not in constr_to_idx:
|
||||||
violation_to_idx[v] = len(violation_to_idx)
|
constr_to_idx[c] = len(constr_to_idx)
|
||||||
violations.append(v)
|
constrs.append(c)
|
||||||
y_sample.append(violation_to_idx[v])
|
y_sample.append(constr_to_idx[c])
|
||||||
y.append(y_sample)
|
y.append(y_sample)
|
||||||
|
|
||||||
# Extract features
|
# Extract features
|
||||||
@@ -62,8 +62,8 @@ class MemorizingLazyConstrComponent:
|
|||||||
logger.info("Constructing matrices...")
|
logger.info("Constructing matrices...")
|
||||||
assert n_features is not None
|
assert n_features is not None
|
||||||
self.n_features_ = n_features
|
self.n_features_ = n_features
|
||||||
self.violations_ = violations
|
self.constrs_ = constrs
|
||||||
self.n_targets_ = len(violation_to_idx)
|
self.n_targets_ = len(constr_to_idx)
|
||||||
x_np = np.vstack(x)
|
x_np = np.vstack(x)
|
||||||
assert x_np.shape == (n_samples, n_features)
|
assert x_np.shape == (n_samples, n_features)
|
||||||
y_np = MultiLabelBinarizer().fit_transform(y)
|
y_np = MultiLabelBinarizer().fit_transform(y)
|
||||||
@@ -82,8 +82,8 @@ class MemorizingLazyConstrComponent:
|
|||||||
model: GurobiModel,
|
model: GurobiModel,
|
||||||
stats: Dict[str, Any],
|
stats: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self.violations_ is not None
|
assert self.constrs_ is not None
|
||||||
if model.fix_violations is None:
|
if model.lazy_enforce is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Read features
|
# Read features
|
||||||
@@ -99,7 +99,7 @@ class MemorizingLazyConstrComponent:
|
|||||||
y = y.reshape(-1)
|
y = y.reshape(-1)
|
||||||
|
|
||||||
# Enforce constraints
|
# Enforce constraints
|
||||||
violations = [self.violations_[i] for (i, yi) in enumerate(y) if yi > 0.5]
|
violations = [self.constrs_[i] for (i, yi) in enumerate(y) if yi > 0.5]
|
||||||
logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...")
|
logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...")
|
||||||
model.fix_violations(model, violations, "aot")
|
model.lazy_enforce(model, violations, "aot")
|
||||||
stats["Lazy Constraints: AOT"] = len(violations)
|
stats["Lazy Constraints: AOT"] = len(violations)
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel:
|
|||||||
name="eq_degree",
|
name="eq_degree",
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_violations(model: GurobiModel) -> List[Any]:
|
def lazy_separate(model: GurobiModel) -> List[Any]:
|
||||||
violations = []
|
violations = []
|
||||||
x = model.inner.cbGetSolution(model.inner._x)
|
x = model.inner.cbGetSolution(model.inner._x)
|
||||||
selected_edges = [e for e in model.inner._edges if x[e] > 0.5]
|
selected_edges = [e for e in model.inner._edges if x[e] > 0.5]
|
||||||
@@ -159,7 +159,7 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel:
|
|||||||
violations.append(cut_edges)
|
violations.append(cut_edges)
|
||||||
return violations
|
return violations
|
||||||
|
|
||||||
def fix_violations(model: GurobiModel, violations: List[Any], where: str) -> None:
|
def lazy_enforce(model: GurobiModel, violations: List[Any], where: str) -> None:
|
||||||
for violation in violations:
|
for violation in violations:
|
||||||
constr = quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2
|
constr = quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2
|
||||||
if where == "cb":
|
if where == "cb":
|
||||||
@@ -172,6 +172,6 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel:
|
|||||||
|
|
||||||
return GurobiModel(
|
return GurobiModel(
|
||||||
model,
|
model,
|
||||||
find_violations=find_violations,
|
lazy_separate=lazy_separate,
|
||||||
fix_violations=fix_violations,
|
lazy_enforce=lazy_enforce,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,13 +21,13 @@ class GurobiModel(AbstractModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
inner: gp.Model,
|
inner: gp.Model,
|
||||||
find_violations: Optional[Callable] = None,
|
lazy_separate: Optional[Callable] = None,
|
||||||
fix_violations: Optional[Callable] = None,
|
lazy_enforce: Optional[Callable] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.fix_violations = fix_violations
|
self.lazy_separate = lazy_separate
|
||||||
self.find_violations = find_violations
|
self.lazy_enforce = lazy_enforce
|
||||||
self.inner = inner
|
self.inner = inner
|
||||||
self.violations_: Optional[List[Any]] = None
|
self.lazy_constrs_: Optional[List[Any]] = None
|
||||||
|
|
||||||
def add_constrs(
|
def add_constrs(
|
||||||
self,
|
self,
|
||||||
@@ -125,18 +125,18 @@ class GurobiModel(AbstractModel):
|
|||||||
stats["Fixed variables"] = n_fixed
|
stats["Fixed variables"] = n_fixed
|
||||||
|
|
||||||
def optimize(self) -> None:
|
def optimize(self) -> None:
|
||||||
self.violations_ = []
|
self.lazy_constrs_ = []
|
||||||
|
|
||||||
def callback(m: gp.Model, where: int) -> None:
|
def callback(m: gp.Model, where: int) -> None:
|
||||||
assert self.find_violations is not None
|
assert self.lazy_separate is not None
|
||||||
assert self.violations_ is not None
|
assert self.lazy_constrs_ is not None
|
||||||
assert self.fix_violations is not None
|
assert self.lazy_enforce is not None
|
||||||
if where == GRB.Callback.MIPSOL:
|
if where == GRB.Callback.MIPSOL:
|
||||||
violations = self.find_violations(self)
|
violations = self.lazy_separate(self)
|
||||||
self.violations_.extend(violations)
|
self.lazy_constrs_.extend(violations)
|
||||||
self.fix_violations(self, violations, "cb")
|
self.lazy_enforce(self, violations, "cb")
|
||||||
|
|
||||||
if self.fix_violations is not None:
|
if self.lazy_enforce is not None:
|
||||||
self.inner.Params.lazyConstraints = 1
|
self.inner.Params.lazyConstraints = 1
|
||||||
self.inner.optimize(callback)
|
self.inner.optimize(callback)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -33,10 +33,10 @@ def test_mem_component(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Should store violations
|
# Should store violations
|
||||||
assert comp.violations_ is not None
|
assert comp.constrs_ is not None
|
||||||
assert comp.n_features_ == 190
|
assert comp.n_features_ == 190
|
||||||
assert comp.n_targets_ == 22
|
assert comp.n_targets_ == 22
|
||||||
assert len(comp.violations_) == 22
|
assert len(comp.constrs_) == 22
|
||||||
|
|
||||||
# Call before-mip
|
# Call before-mip
|
||||||
stats: Dict[str, Any] = {}
|
stats: Dict[str, Any] = {}
|
||||||
|
|||||||
BIN
tests/fixtures/tsp-n20-00000.h5
vendored
BIN
tests/fixtures/tsp-n20-00000.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00000.mps.gz
vendored
BIN
tests/fixtures/tsp-n20-00000.mps.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00000.pkl.gz
vendored
BIN
tests/fixtures/tsp-n20-00000.pkl.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00001.h5
vendored
BIN
tests/fixtures/tsp-n20-00001.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00001.mps.gz
vendored
BIN
tests/fixtures/tsp-n20-00001.mps.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00001.pkl.gz
vendored
BIN
tests/fixtures/tsp-n20-00001.pkl.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.h5
vendored
BIN
tests/fixtures/tsp-n20-00002.h5
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.mps.gz
vendored
BIN
tests/fixtures/tsp-n20-00002.mps.gz
vendored
Binary file not shown.
BIN
tests/fixtures/tsp-n20-00002.pkl.gz
vendored
BIN
tests/fixtures/tsp-n20-00002.pkl.gz
vendored
Binary file not shown.
Reference in New Issue
Block a user