Rewrite StaticLazy.sample_xy

This commit is contained in:
2021-04-12 07:35:51 -05:00
parent 2979bd157c
commit bccf0e9860
3 changed files with 98 additions and 15 deletions

View File

@@ -16,6 +16,7 @@ from miplearn.features import (
InstanceFeatures,
Features,
Constraint,
Sample,
)
from miplearn.instance.base import Instance
from miplearn.solvers.internal import InternalSolver
@@ -25,6 +26,50 @@ from miplearn.types import (
)
@pytest.fixture
def sample() -> Sample:
sample = Sample(
after_load=Features(
constraints={
"c1": Constraint(category="type-a", lazy=True),
"c2": Constraint(category="type-a", lazy=True),
"c3": Constraint(category="type-a", lazy=True),
"c4": Constraint(category="type-b", lazy=True),
"c5": Constraint(category="type-b", lazy=False),
}
),
after_lp=Features(
instance=InstanceFeatures(),
constraints={
"c1": Constraint(),
"c2": Constraint(),
"c3": Constraint(),
"c4": Constraint(),
"c5": Constraint(),
},
),
after_mip=Features(
extra={
"lazy_enforced": {"c1", "c2", "c4"},
}
),
)
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_lp.constraints["c1"].to_list = Mock( # type: ignore
return_value=[1.0, 1.0]
)
sample.after_lp.constraints["c2"].to_list = Mock( # type: ignore
return_value=[1.0, 2.0]
)
sample.after_lp.constraints["c3"].to_list = Mock( # type: ignore
return_value=[1.0, 3.0]
)
sample.after_lp.constraints["c4"].to_list = Mock( # type: ignore
return_value=[1.0, 4.0, 0.0]
)
return sample
@pytest.fixture
def instance(features: Features) -> Instance:
instance = Mock(spec=Instance)
@@ -34,7 +79,7 @@ def instance(features: Features) -> Instance:
@pytest.fixture
def sample() -> TrainingSample:
def sample2() -> TrainingSample:
return TrainingSample(
lazy_enforced={"c1", "c2", "c4"},
)
@@ -112,7 +157,7 @@ def test_usage_with_solver(instance: Instance) -> None:
)
)
sample: TrainingSample = TrainingSample()
sample2: TrainingSample = TrainingSample()
stats: LearningSolveStats = {}
# LearningSolver calls before_solve_mip
@@ -122,7 +167,7 @@ def test_usage_with_solver(instance: Instance) -> None:
model=None,
stats=stats,
features=instance.features,
training_data=sample,
training_data=sample2,
)
# Should ask ML to predict whether each lazy constraint should be enforced
@@ -160,11 +205,11 @@ def test_usage_with_solver(instance: Instance) -> None:
model=None,
stats=stats,
features=instance.features,
training_data=sample,
training_data=sample2,
)
# Should update training sample
assert sample.lazy_enforced == {"c1", "c2", "c3", "c4"}
assert sample2.lazy_enforced == {"c1", "c2", "c3", "c4"}
# Should update stats
assert stats["LazyStatic: Removed"] == 1
@@ -175,7 +220,7 @@ def test_usage_with_solver(instance: Instance) -> None:
def test_sample_predict(
instance: Instance,
sample: TrainingSample,
sample2: TrainingSample,
) -> None:
comp = StaticLazyConstraintsComponent()
comp.thresholds["type-a"] = MinProbabilityThreshold([0.5, 0.5])
@@ -194,7 +239,7 @@ def test_sample_predict(
[0.0, 1.0], # c4
]
)
pred = comp.sample_predict(instance, sample)
pred = comp.sample_predict(instance, sample2)
assert pred == ["c1", "c2", "c4"]
@@ -238,19 +283,16 @@ def test_fit_xy() -> None:
assert thr_b.fit.call_args[0][0] == clf_b # type: ignore
def test_sample_xy(
instance: Instance,
sample: TrainingSample,
) -> None:
def test_sample_xy(sample: Sample) -> None:
x_expected = {
"type-a": [[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]],
"type-b": [[1.0, 4.0, 0.0]],
"type-a": [[5.0, 1.0, 1.0], [5.0, 1.0, 2.0], [5.0, 1.0, 3.0]],
"type-b": [[5.0, 1.0, 4.0, 0.0]],
}
y_expected = {
"type-a": [[False, True], [False, True], [True, False]],
"type-b": [[False, True]],
}
xy = StaticLazyConstraintsComponent().sample_xy_old(instance, sample)
xy = StaticLazyConstraintsComponent().sample_xy(sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected