mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Move user_cuts/lazy_enforced to sample.data
This commit is contained in:
@@ -28,11 +28,11 @@ def training_instances() -> List[Instance]:
|
||||
samples_0 = [
|
||||
Sample(
|
||||
after_load=Features(instance=InstanceFeatures()),
|
||||
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
|
||||
data={"lazy_enforced": {"c1", "c2"}},
|
||||
),
|
||||
Sample(
|
||||
after_load=Features(instance=InstanceFeatures()),
|
||||
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
|
||||
data={"lazy_enforced": {"c2", "c3"}},
|
||||
),
|
||||
]
|
||||
samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||
@@ -57,7 +57,7 @@ def training_instances() -> List[Instance]:
|
||||
samples_1 = [
|
||||
Sample(
|
||||
after_load=Features(instance=InstanceFeatures()),
|
||||
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
|
||||
data={"lazy_enforced": {"c3", "c4"}},
|
||||
)
|
||||
]
|
||||
samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
|
||||
|
||||
@@ -81,10 +81,9 @@ def test_usage(
|
||||
) -> None:
|
||||
stats_before = solver.solve(stab_instance)
|
||||
sample = stab_instance.get_samples()[0]
|
||||
assert sample.after_mip is not None
|
||||
assert sample.after_mip.extra is not None
|
||||
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
||||
print(stats_before)
|
||||
user_cuts_enforced = sample.get("user_cuts_enforced")
|
||||
assert user_cuts_enforced is not None
|
||||
assert len(user_cuts_enforced) > 0
|
||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||
assert stats_before["UserCuts: Added in callback"] > 0
|
||||
|
||||
|
||||
@@ -48,11 +48,9 @@ def sample() -> Sample:
|
||||
instance=InstanceFeatures(),
|
||||
constraints=ConstraintFeatures(names=["c1", "c2", "c3", "c4", "c5"]),
|
||||
),
|
||||
after_mip=Features(
|
||||
extra={
|
||||
"lazy_enforced": {"c1", "c2", "c4"},
|
||||
}
|
||||
),
|
||||
data={
|
||||
"lazy_enforced": {"c1", "c2", "c4"},
|
||||
},
|
||||
)
|
||||
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||
sample.after_lp.constraints.to_list = Mock( # type: ignore
|
||||
@@ -112,10 +110,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
||||
|
||||
stats: LearningSolveStats = {}
|
||||
sample = instance.get_samples()[0]
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_mip is not None
|
||||
assert sample.after_mip.extra is not None
|
||||
del sample.after_mip.extra["lazy_enforced"]
|
||||
assert sample.get("lazy_enforced") is not None
|
||||
|
||||
# LearningSolver calls before_solve_mip
|
||||
component.before_solve_mip(
|
||||
@@ -140,6 +135,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
||||
|
||||
# Should ask internal solver to verify if constraints in the pool are
|
||||
# satisfied and add the ones that are not
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.constraints is not None
|
||||
c = sample.after_load.constraints[[False, False, True, False, False]]
|
||||
internal.are_constraints_satisfied.assert_called_once_with(c, tol=1.0)
|
||||
@@ -165,7 +161,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
||||
)
|
||||
|
||||
# Should update training sample
|
||||
assert sample.after_mip.extra["lazy_enforced"] == {"c1", "c2", "c3", "c4"}
|
||||
assert sample.get("lazy_enforced") == {"c1", "c2", "c3", "c4"}
|
||||
#
|
||||
# Should update stats
|
||||
assert stats["LazyStatic: Removed"] == 1
|
||||
|
||||
@@ -67,15 +67,14 @@ def test_subtour() -> None:
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
assert len(instance.get_samples()) == 1
|
||||
sample = instance.get_samples()[0]
|
||||
assert sample.after_mip is not None
|
||||
features = sample.after_mip
|
||||
assert features.extra is not None
|
||||
assert "lazy_enforced" in features.extra
|
||||
lazy_enforced = features.extra["lazy_enforced"]
|
||||
samples = instance.get_samples()
|
||||
assert len(samples) == 1
|
||||
sample = samples[0]
|
||||
lazy_enforced = sample.get("lazy_enforced")
|
||||
assert lazy_enforced is not None
|
||||
assert len(lazy_enforced) > 0
|
||||
assert sample.after_mip is not None
|
||||
features = sample.after_mip
|
||||
assert features.variables is not None
|
||||
assert features.variables.values == [
|
||||
1.0,
|
||||
|
||||
Reference in New Issue
Block a user