mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Move user_cuts/lazy_enforced to sample.data
This commit is contained in:
@@ -81,10 +81,9 @@ class DynamicConstraintsComponent(Component):
|
|||||||
cids[category].append(cid)
|
cids[category].append(cid)
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
if sample.after_mip is not None:
|
enforced_cids = sample.get(self.attr)
|
||||||
assert sample.after_mip.extra is not None
|
if enforced_cids is not None:
|
||||||
if sample.after_mip.extra[self.attr] is not None:
|
if cid in enforced_cids:
|
||||||
if cid in sample.after_mip.extra[self.attr]:
|
|
||||||
y[category] += [[False, True]]
|
y[category] += [[False, True]]
|
||||||
else:
|
else:
|
||||||
y[category] += [[True, False]]
|
y[category] += [[True, False]]
|
||||||
@@ -133,13 +132,7 @@ class DynamicConstraintsComponent(Component):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||||
if (
|
return sample.get(self.attr)
|
||||||
sample.after_mip is None
|
|
||||||
or sample.after_mip.extra is None
|
|
||||||
or sample.after_mip.extra[self.attr] is None
|
|
||||||
):
|
|
||||||
return
|
|
||||||
return sample.after_mip.extra[self.attr]
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
@@ -161,10 +154,8 @@ class DynamicConstraintsComponent(Component):
|
|||||||
instance: Instance,
|
instance: Instance,
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> Dict[Hashable, Dict[str, float]]:
|
) -> Dict[Hashable, Dict[str, float]]:
|
||||||
assert sample.after_mip is not None
|
actual = sample.get(self.attr)
|
||||||
assert sample.after_mip.extra is not None
|
assert actual is not None
|
||||||
assert self.attr in sample.after_mip.extra
|
|
||||||
actual = sample.after_mip.extra[self.attr]
|
|
||||||
pred = set(self.sample_predict(instance, sample))
|
pred = set(self.sample_predict(instance, sample))
|
||||||
tp: Dict[Hashable, int] = {}
|
tp: Dict[Hashable, int] = {}
|
||||||
tn: Dict[Hashable, int] = {}
|
tn: Dict[Hashable, int] = {}
|
||||||
|
|||||||
@@ -78,9 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
|
|||||||
stats: LearningSolveStats,
|
stats: LearningSolveStats,
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert sample.after_mip is not None
|
sample.put("lazy_enforced", set(self.lazy_enforced))
|
||||||
assert sample.after_mip.extra is not None
|
|
||||||
sample.after_mip.extra["lazy_enforced"] = set(self.lazy_enforced)
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def iteration_cb(
|
def iteration_cb(
|
||||||
|
|||||||
@@ -87,9 +87,7 @@ class UserCutsComponent(Component):
|
|||||||
stats: LearningSolveStats,
|
stats: LearningSolveStats,
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert sample.after_mip is not None
|
sample.put("user_cuts_enforced", set(self.enforced))
|
||||||
assert sample.after_mip.extra is not None
|
|
||||||
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
|
|
||||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||||
if self.n_added_in_callback > 0:
|
if self.n_added_in_callback > 0:
|
||||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||||
|
|||||||
@@ -60,9 +60,7 @@ class StaticLazyConstraintsComponent(Component):
|
|||||||
stats: LearningSolveStats,
|
stats: LearningSolveStats,
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert sample.after_mip is not None
|
sample.put("lazy_enforced", self.enforced_cids)
|
||||||
assert sample.after_mip.extra is not None
|
|
||||||
sample.after_mip.extra["lazy_enforced"] = self.enforced_cids
|
|
||||||
stats["LazyStatic: Restored"] = self.n_restored
|
stats["LazyStatic: Restored"] = self.n_restored
|
||||||
stats["LazyStatic: Iterations"] = self.n_iterations
|
stats["LazyStatic: Iterations"] = self.n_iterations
|
||||||
|
|
||||||
@@ -236,12 +234,9 @@ class StaticLazyConstraintsComponent(Component):
|
|||||||
cids[category].append(cname)
|
cids[category].append(cname)
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
if (
|
lazy_enforced = sample.get("lazy_enforced")
|
||||||
(sample.after_mip is not None)
|
if lazy_enforced is not None:
|
||||||
and (sample.after_mip.extra is not None)
|
if cname in lazy_enforced:
|
||||||
and ("lazy_enforced" in sample.after_mip.extra)
|
|
||||||
):
|
|
||||||
if cname in sample.after_mip.extra["lazy_enforced"]:
|
|
||||||
y[category] += [[False, True]]
|
y[category] += [[False, True]]
|
||||||
else:
|
else:
|
||||||
y[category] += [[True, False]]
|
y[category] += [[True, False]]
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import collections
|
|||||||
import numbers
|
import numbers
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import log, isfinite
|
from math import log, isfinite
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, List, Hashable, Tuple
|
from typing import TYPE_CHECKING, Dict, Optional, List, Hashable, Tuple, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -140,14 +140,31 @@ class Features:
|
|||||||
constraints: Optional[ConstraintFeatures] = None
|
constraints: Optional[ConstraintFeatures] = None
|
||||||
lp_solve: Optional["LPSolveStats"] = None
|
lp_solve: Optional["LPSolveStats"] = None
|
||||||
mip_solve: Optional["MIPSolveStats"] = None
|
mip_solve: Optional["MIPSolveStats"] = None
|
||||||
extra: Optional[Dict] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Sample:
|
class Sample:
|
||||||
after_load: Optional[Features] = None
|
def __init__(
|
||||||
after_lp: Optional[Features] = None
|
self,
|
||||||
after_mip: Optional[Features] = None
|
after_load: Optional[Features] = None,
|
||||||
|
after_lp: Optional[Features] = None,
|
||||||
|
after_mip: Optional[Features] = None,
|
||||||
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
self._data: Dict[str, Any] = data
|
||||||
|
self.after_load = after_load
|
||||||
|
self.after_lp = after_lp
|
||||||
|
self.after_mip = after_mip
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[Any]:
|
||||||
|
if key in self._data:
|
||||||
|
return self._data[key]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def put(self, key: str, value: Any) -> None:
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
|
||||||
class FeaturesExtractor:
|
class FeaturesExtractor:
|
||||||
|
|||||||
@@ -173,7 +173,6 @@ class LearningSolver:
|
|||||||
"Features (after-load) extracted in %.2f seconds"
|
"Features (after-load) extracted in %.2f seconds"
|
||||||
% (time.time() - initial_time)
|
% (time.time() - initial_time)
|
||||||
)
|
)
|
||||||
features.extra = {}
|
|
||||||
sample.after_load = features
|
sample.after_load = features
|
||||||
|
|
||||||
callback_args = (
|
callback_args = (
|
||||||
@@ -217,7 +216,6 @@ class LearningSolver:
|
|||||||
"Features (after-lp) extracted in %.2f seconds"
|
"Features (after-lp) extracted in %.2f seconds"
|
||||||
% (time.time() - initial_time)
|
% (time.time() - initial_time)
|
||||||
)
|
)
|
||||||
features.extra = {}
|
|
||||||
features.lp_solve = lp_stats
|
features.lp_solve = lp_stats
|
||||||
sample.after_lp = features
|
sample.after_lp = features
|
||||||
|
|
||||||
@@ -291,7 +289,6 @@ class LearningSolver:
|
|||||||
% (time.time() - initial_time)
|
% (time.time() - initial_time)
|
||||||
)
|
)
|
||||||
features.mip_solve = mip_stats
|
features.mip_solve = mip_stats
|
||||||
features.extra = {}
|
|
||||||
sample.after_mip = features
|
sample.after_mip = features
|
||||||
|
|
||||||
# After-solve callbacks
|
# After-solve callbacks
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ def training_instances() -> List[Instance]:
|
|||||||
samples_0 = [
|
samples_0 = [
|
||||||
Sample(
|
Sample(
|
||||||
after_load=Features(instance=InstanceFeatures()),
|
after_load=Features(instance=InstanceFeatures()),
|
||||||
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
|
data={"lazy_enforced": {"c1", "c2"}},
|
||||||
),
|
),
|
||||||
Sample(
|
Sample(
|
||||||
after_load=Features(instance=InstanceFeatures()),
|
after_load=Features(instance=InstanceFeatures()),
|
||||||
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
|
data={"lazy_enforced": {"c2", "c3"}},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||||
@@ -57,7 +57,7 @@ def training_instances() -> List[Instance]:
|
|||||||
samples_1 = [
|
samples_1 = [
|
||||||
Sample(
|
Sample(
|
||||||
after_load=Features(instance=InstanceFeatures()),
|
after_load=Features(instance=InstanceFeatures()),
|
||||||
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
|
data={"lazy_enforced": {"c3", "c4"}},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
|
samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
|
||||||
|
|||||||
@@ -81,10 +81,9 @@ def test_usage(
|
|||||||
) -> None:
|
) -> None:
|
||||||
stats_before = solver.solve(stab_instance)
|
stats_before = solver.solve(stab_instance)
|
||||||
sample = stab_instance.get_samples()[0]
|
sample = stab_instance.get_samples()[0]
|
||||||
assert sample.after_mip is not None
|
user_cuts_enforced = sample.get("user_cuts_enforced")
|
||||||
assert sample.after_mip.extra is not None
|
assert user_cuts_enforced is not None
|
||||||
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
assert len(user_cuts_enforced) > 0
|
||||||
print(stats_before)
|
|
||||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||||
assert stats_before["UserCuts: Added in callback"] > 0
|
assert stats_before["UserCuts: Added in callback"] > 0
|
||||||
|
|
||||||
|
|||||||
@@ -48,11 +48,9 @@ def sample() -> Sample:
|
|||||||
instance=InstanceFeatures(),
|
instance=InstanceFeatures(),
|
||||||
constraints=ConstraintFeatures(names=["c1", "c2", "c3", "c4", "c5"]),
|
constraints=ConstraintFeatures(names=["c1", "c2", "c3", "c4", "c5"]),
|
||||||
),
|
),
|
||||||
after_mip=Features(
|
data={
|
||||||
extra={
|
|
||||||
"lazy_enforced": {"c1", "c2", "c4"},
|
"lazy_enforced": {"c1", "c2", "c4"},
|
||||||
}
|
},
|
||||||
),
|
|
||||||
)
|
)
|
||||||
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||||
sample.after_lp.constraints.to_list = Mock( # type: ignore
|
sample.after_lp.constraints.to_list = Mock( # type: ignore
|
||||||
@@ -112,10 +110,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
|||||||
|
|
||||||
stats: LearningSolveStats = {}
|
stats: LearningSolveStats = {}
|
||||||
sample = instance.get_samples()[0]
|
sample = instance.get_samples()[0]
|
||||||
assert sample.after_load is not None
|
assert sample.get("lazy_enforced") is not None
|
||||||
assert sample.after_mip is not None
|
|
||||||
assert sample.after_mip.extra is not None
|
|
||||||
del sample.after_mip.extra["lazy_enforced"]
|
|
||||||
|
|
||||||
# LearningSolver calls before_solve_mip
|
# LearningSolver calls before_solve_mip
|
||||||
component.before_solve_mip(
|
component.before_solve_mip(
|
||||||
@@ -140,6 +135,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
|||||||
|
|
||||||
# Should ask internal solver to verify if constraints in the pool are
|
# Should ask internal solver to verify if constraints in the pool are
|
||||||
# satisfied and add the ones that are not
|
# satisfied and add the ones that are not
|
||||||
|
assert sample.after_load is not None
|
||||||
assert sample.after_load.constraints is not None
|
assert sample.after_load.constraints is not None
|
||||||
c = sample.after_load.constraints[[False, False, True, False, False]]
|
c = sample.after_load.constraints[[False, False, True, False, False]]
|
||||||
internal.are_constraints_satisfied.assert_called_once_with(c, tol=1.0)
|
internal.are_constraints_satisfied.assert_called_once_with(c, tol=1.0)
|
||||||
@@ -165,7 +161,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should update training sample
|
# Should update training sample
|
||||||
assert sample.after_mip.extra["lazy_enforced"] == {"c1", "c2", "c3", "c4"}
|
assert sample.get("lazy_enforced") == {"c1", "c2", "c3", "c4"}
|
||||||
#
|
#
|
||||||
# Should update stats
|
# Should update stats
|
||||||
assert stats["LazyStatic: Removed"] == 1
|
assert stats["LazyStatic: Removed"] == 1
|
||||||
|
|||||||
@@ -67,15 +67,14 @@ def test_subtour() -> None:
|
|||||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||||
solver = LearningSolver()
|
solver = LearningSolver()
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
assert len(instance.get_samples()) == 1
|
samples = instance.get_samples()
|
||||||
sample = instance.get_samples()[0]
|
assert len(samples) == 1
|
||||||
assert sample.after_mip is not None
|
sample = samples[0]
|
||||||
features = sample.after_mip
|
lazy_enforced = sample.get("lazy_enforced")
|
||||||
assert features.extra is not None
|
|
||||||
assert "lazy_enforced" in features.extra
|
|
||||||
lazy_enforced = features.extra["lazy_enforced"]
|
|
||||||
assert lazy_enforced is not None
|
assert lazy_enforced is not None
|
||||||
assert len(lazy_enforced) > 0
|
assert len(lazy_enforced) > 0
|
||||||
|
assert sample.after_mip is not None
|
||||||
|
features = sample.after_mip
|
||||||
assert features.variables is not None
|
assert features.variables is not None
|
||||||
assert features.variables.values == [
|
assert features.variables.values == [
|
||||||
1.0,
|
1.0,
|
||||||
|
|||||||
Reference in New Issue
Block a user