Replace instance.samples by instance.get/push_sample

This commit is contained in:
2021-06-29 16:49:24 -05:00
parent a5092cc2b9
commit 80281df8d8
9 changed files with 48 additions and 34 deletions

View File

@@ -25,7 +25,7 @@ E = 0.1
@pytest.fixture
def training_instances() -> List[Instance]:
instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)]
instances[0].samples = [
samples_0 = [
Sample(
after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
@@ -35,12 +35,9 @@ def training_instances() -> List[Instance]:
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
),
]
instances[0].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
samples_0[1].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
instances[0].get_samples = Mock(return_value=samples_0) # type: ignore
instances[0].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": "type-a",
@@ -57,15 +54,14 @@ def training_instances() -> List[Instance]:
"c4": [3.0, 4.0],
}
)
instances[1].samples = [
samples_1 = [
Sample(
after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
)
]
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0]
)
samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
instances[1].get_samples = Mock(return_value=samples_1) # type: ignore
instances[1].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": None,
@@ -97,7 +93,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
}
x_actual, y_actual = comp.sample_xy(
training_instances[0],
training_instances[0].samples[0],
training_instances[0].get_samples()[0],
)
assert_equals(x_actual, x_expected)
assert_equals(y_actual, y_expected)
@@ -184,12 +180,12 @@ def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
)
pred = comp.sample_predict(
training_instances[0],
training_instances[0].samples[0],
training_instances[0].get_samples()[0],
)
assert pred == ["c1", "c4"]
ev = comp.sample_evaluate(
training_instances[0],
training_instances[0].samples[0],
training_instances[0].get_samples()[0],
)
assert ev == {
"type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1),

View File

@@ -80,7 +80,7 @@ def test_usage(
solver: LearningSolver,
) -> None:
stats_before = solver.solve(stab_instance)
sample = stab_instance.samples[0]
sample = stab_instance.get_samples()[0]
assert sample.after_mip is not None
assert sample.after_mip.extra is not None
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0

View File

@@ -70,7 +70,7 @@ def sample() -> Sample:
@pytest.fixture
def instance(sample: Sample) -> Instance:
instance = Mock(spec=Instance)
instance.samples = [sample]
instance.get_samples = Mock(return_value=[sample]) # type: ignore
instance.has_static_lazy_constraints = Mock(return_value=True)
return instance
@@ -111,7 +111,7 @@ def test_usage_with_solver(instance: Instance) -> None:
)
stats: LearningSolveStats = {}
sample = instance.samples[0]
sample = instance.get_samples()[0]
assert sample.after_load is not None
assert sample.after_mip is not None
assert sample.after_mip.extra is not None

View File

@@ -39,9 +39,10 @@ def test_instance() -> None:
instance = TravelingSalesmanInstance(n_cities, distances)
solver = LearningSolver()
solver.solve(instance)
assert len(instance.samples) == 1
assert instance.samples[0].after_mip is not None
features = instance.samples[0].after_mip
assert len(instance.get_samples()) == 1
sample = instance.get_samples()[0]
assert sample.after_mip is not None
features = sample.after_mip
assert features is not None
assert features.variables is not None
assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
@@ -66,9 +67,10 @@ def test_subtour() -> None:
instance = TravelingSalesmanInstance(n_cities, distances)
solver = LearningSolver()
solver.solve(instance)
assert len(instance.samples) == 1
assert instance.samples[0].after_mip is not None
features = instance.samples[0].after_mip
assert len(instance.get_samples()) == 1
sample = instance.get_samples()[0]
assert sample.after_mip is not None
features = sample.after_mip
assert features.extra is not None
assert "lazy_enforced" in features.extra
lazy_enforced = features.extra["lazy_enforced"]

View File

@@ -35,8 +35,8 @@ def test_learning_solver(
)
solver.solve(instance)
assert len(instance.samples) > 0
sample = instance.samples[0]
assert len(instance.get_samples()) > 0
sample = instance.get_samples()[0]
after_mip = sample.after_mip
assert after_mip is not None
@@ -90,7 +90,7 @@ def test_parallel_solve(
results = solver.parallel_solve(instances, n_jobs=3)
assert len(results) == 10
for instance in instances:
assert len(instance.samples) == 1
assert len(instance.get_samples()) == 1
def test_solve_fit_from_disk(
@@ -109,13 +109,13 @@ def test_solve_fit_from_disk(
solver = LearningSolver(solver=internal_solver)
solver.solve(instances[0])
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
assert len(instance_loaded.samples) > 0
assert len(instance_loaded.get_samples()) > 0
# Test: parallel_solve
solver.parallel_solve(instances)
for instance in instances:
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename)
assert len(instance_loaded.samples) > 0
assert len(instance_loaded.get_samples()) > 0
# Delete temporary files
for instance in instances: