Only include static features in after-load

This commit is contained in:
2021-04-13 16:08:30 -05:00
parent 8f41278713
commit ef7a50e871
11 changed files with 170 additions and 182 deletions

View File

@@ -28,18 +28,18 @@ def training_instances() -> List[Instance]:
instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)]
instances[0].samples = [
Sample(
after_lp=Features(instance=InstanceFeatures()),
after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
),
Sample(
after_lp=Features(instance=InstanceFeatures()),
after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
),
]
instances[0].samples[0].after_lp.instance.to_list = Mock( # type: ignore
instances[0].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].samples[1].after_lp.instance.to_list = Mock( # type: ignore
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].get_constraint_category = Mock( # type: ignore
@@ -60,11 +60,11 @@ def training_instances() -> List[Instance]:
)
instances[1].samples = [
Sample(
after_lp=Features(instance=InstanceFeatures()),
after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
)
]
instances[1].samples[0].after_lp.instance.to_list = Mock( # type: ignore
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0]
)
instances[1].get_constraint_category = Mock( # type: ignore

View File

@@ -27,6 +27,7 @@ from miplearn.solvers.tests import assert_equals
def sample() -> Sample:
sample = Sample(
after_load=Features(
instance=InstanceFeatures(),
variables={
"x[0]": Variable(category="default"),
"x[1]": Variable(category=None),
@@ -35,7 +36,6 @@ def sample() -> Sample:
},
),
after_lp=Features(
instance=InstanceFeatures(),
variables={
"x[0]": Variable(),
"x[1]": Variable(),
@@ -52,7 +52,7 @@ def sample() -> Sample:
}
),
)
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_lp.variables["x[0]"].to_list = Mock( # type: ignore
return_value=[0.0, 0.0]
)