mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Replace instance.samples by instance.get/push_sample
This commit is contained in:
@@ -187,13 +187,14 @@ class Component:
|
|||||||
instances: List[Instance],
|
instances: List[Instance],
|
||||||
n_jobs: int = 1,
|
n_jobs: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
# Part I: Pre-fit
|
# Part I: Pre-fit
|
||||||
def _pre_sample_xy(instance: Instance) -> Dict:
|
def _pre_sample_xy(instance: Instance) -> Dict:
|
||||||
pre_instance: Dict = {}
|
pre_instance: Dict = {}
|
||||||
for (cidx, comp) in enumerate(components):
|
for (cidx, comp) in enumerate(components):
|
||||||
pre_instance[cidx] = []
|
pre_instance[cidx] = []
|
||||||
instance.load()
|
instance.load()
|
||||||
for sample in instance.samples:
|
for sample in instance.get_samples():
|
||||||
for (cidx, comp) in enumerate(components):
|
for (cidx, comp) in enumerate(components):
|
||||||
pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
|
pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
|
||||||
instance.free()
|
instance.free()
|
||||||
@@ -219,7 +220,7 @@ class Component:
|
|||||||
x_instance[cidx] = {}
|
x_instance[cidx] = {}
|
||||||
y_instance[cidx] = {}
|
y_instance[cidx] = {}
|
||||||
instance.load()
|
instance.load()
|
||||||
for sample in instance.samples:
|
for sample in instance.get_samples():
|
||||||
for (cidx, comp) in enumerate(components):
|
for (cidx, comp) in enumerate(components):
|
||||||
x = x_instance[cidx]
|
x = x_instance[cidx]
|
||||||
y = y_instance[cidx]
|
y = y_instance[cidx]
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class Instance(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.samples: List[Sample] = []
|
self._samples: List[Sample] = []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
@@ -189,3 +189,9 @@ class Instance(ABC):
|
|||||||
Save any pending changes made to the instance to the underlying data store.
|
Save any pending changes made to the instance to the underlying data store.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_samples(self) -> List[Sample]:
|
||||||
|
return self._samples
|
||||||
|
|
||||||
|
def push_sample(self, sample: Sample) -> None:
|
||||||
|
self._samples.append(sample)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
|
|||||||
|
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
|
from miplearn.features import Sample
|
||||||
from miplearn.instance.base import Instance
|
from miplearn.instance.base import Instance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -120,18 +121,26 @@ class PickleGzInstance(Instance):
|
|||||||
obj = read_pickle_gz(self.filename)
|
obj = read_pickle_gz(self.filename)
|
||||||
assert isinstance(obj, Instance)
|
assert isinstance(obj, Instance)
|
||||||
self.instance = obj
|
self.instance = obj
|
||||||
self.samples = self.instance.samples
|
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def free(self) -> None:
|
def free(self) -> None:
|
||||||
self.instance = None # type: ignore
|
self.instance = None # type: ignore
|
||||||
self.samples = None # type: ignore
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def flush(self) -> None:
|
def flush(self) -> None:
|
||||||
write_pickle_gz(self.instance, self.filename)
|
write_pickle_gz(self.instance, self.filename)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def get_samples(self) -> List[Sample]:
|
||||||
|
assert self.instance is not None
|
||||||
|
return self.instance.get_samples()
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def push_sample(self, sample: Sample) -> None:
|
||||||
|
assert self.instance is not None
|
||||||
|
self.instance.push_sample(sample)
|
||||||
|
|
||||||
|
|
||||||
def write_pickle_gz(obj: Any, filename: str) -> None:
|
def write_pickle_gz(obj: Any, filename: str) -> None:
|
||||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class LearningSolver:
|
|||||||
# Initialize training sample
|
# Initialize training sample
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
sample = Sample()
|
sample = Sample()
|
||||||
instance.samples.append(sample)
|
instance.push_sample(sample)
|
||||||
|
|
||||||
# Initialize stats
|
# Initialize stats
|
||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ E = 0.1
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def training_instances() -> List[Instance]:
|
def training_instances() -> List[Instance]:
|
||||||
instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)]
|
instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)]
|
||||||
instances[0].samples = [
|
samples_0 = [
|
||||||
Sample(
|
Sample(
|
||||||
after_load=Features(instance=InstanceFeatures()),
|
after_load=Features(instance=InstanceFeatures()),
|
||||||
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
|
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
|
||||||
@@ -35,12 +35,9 @@ def training_instances() -> List[Instance]:
|
|||||||
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
|
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
instances[0].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||||
return_value=[5.0]
|
samples_0[1].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
|
||||||
)
|
instances[0].get_samples = Mock(return_value=samples_0) # type: ignore
|
||||||
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
|
|
||||||
return_value=[5.0]
|
|
||||||
)
|
|
||||||
instances[0].get_constraint_categories = Mock( # type: ignore
|
instances[0].get_constraint_categories = Mock( # type: ignore
|
||||||
return_value={
|
return_value={
|
||||||
"c1": "type-a",
|
"c1": "type-a",
|
||||||
@@ -57,15 +54,14 @@ def training_instances() -> List[Instance]:
|
|||||||
"c4": [3.0, 4.0],
|
"c4": [3.0, 4.0],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
instances[1].samples = [
|
samples_1 = [
|
||||||
Sample(
|
Sample(
|
||||||
after_load=Features(instance=InstanceFeatures()),
|
after_load=Features(instance=InstanceFeatures()),
|
||||||
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
|
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
|
||||||
return_value=[8.0]
|
instances[1].get_samples = Mock(return_value=samples_1) # type: ignore
|
||||||
)
|
|
||||||
instances[1].get_constraint_categories = Mock( # type: ignore
|
instances[1].get_constraint_categories = Mock( # type: ignore
|
||||||
return_value={
|
return_value={
|
||||||
"c1": None,
|
"c1": None,
|
||||||
@@ -97,7 +93,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
|||||||
}
|
}
|
||||||
x_actual, y_actual = comp.sample_xy(
|
x_actual, y_actual = comp.sample_xy(
|
||||||
training_instances[0],
|
training_instances[0],
|
||||||
training_instances[0].samples[0],
|
training_instances[0].get_samples()[0],
|
||||||
)
|
)
|
||||||
assert_equals(x_actual, x_expected)
|
assert_equals(x_actual, x_expected)
|
||||||
assert_equals(y_actual, y_expected)
|
assert_equals(y_actual, y_expected)
|
||||||
@@ -184,12 +180,12 @@ def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
|
|||||||
)
|
)
|
||||||
pred = comp.sample_predict(
|
pred = comp.sample_predict(
|
||||||
training_instances[0],
|
training_instances[0],
|
||||||
training_instances[0].samples[0],
|
training_instances[0].get_samples()[0],
|
||||||
)
|
)
|
||||||
assert pred == ["c1", "c4"]
|
assert pred == ["c1", "c4"]
|
||||||
ev = comp.sample_evaluate(
|
ev = comp.sample_evaluate(
|
||||||
training_instances[0],
|
training_instances[0],
|
||||||
training_instances[0].samples[0],
|
training_instances[0].get_samples()[0],
|
||||||
)
|
)
|
||||||
assert ev == {
|
assert ev == {
|
||||||
"type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1),
|
"type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1),
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def test_usage(
|
|||||||
solver: LearningSolver,
|
solver: LearningSolver,
|
||||||
) -> None:
|
) -> None:
|
||||||
stats_before = solver.solve(stab_instance)
|
stats_before = solver.solve(stab_instance)
|
||||||
sample = stab_instance.samples[0]
|
sample = stab_instance.get_samples()[0]
|
||||||
assert sample.after_mip is not None
|
assert sample.after_mip is not None
|
||||||
assert sample.after_mip.extra is not None
|
assert sample.after_mip.extra is not None
|
||||||
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def sample() -> Sample:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def instance(sample: Sample) -> Instance:
|
def instance(sample: Sample) -> Instance:
|
||||||
instance = Mock(spec=Instance)
|
instance = Mock(spec=Instance)
|
||||||
instance.samples = [sample]
|
instance.get_samples = Mock(return_value=[sample]) # type: ignore
|
||||||
instance.has_static_lazy_constraints = Mock(return_value=True)
|
instance.has_static_lazy_constraints = Mock(return_value=True)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@@ -111,7 +111,7 @@ def test_usage_with_solver(instance: Instance) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
stats: LearningSolveStats = {}
|
stats: LearningSolveStats = {}
|
||||||
sample = instance.samples[0]
|
sample = instance.get_samples()[0]
|
||||||
assert sample.after_load is not None
|
assert sample.after_load is not None
|
||||||
assert sample.after_mip is not None
|
assert sample.after_mip is not None
|
||||||
assert sample.after_mip.extra is not None
|
assert sample.after_mip.extra is not None
|
||||||
|
|||||||
@@ -39,9 +39,10 @@ def test_instance() -> None:
|
|||||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||||
solver = LearningSolver()
|
solver = LearningSolver()
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
assert len(instance.samples) == 1
|
assert len(instance.get_samples()) == 1
|
||||||
assert instance.samples[0].after_mip is not None
|
sample = instance.get_samples()[0]
|
||||||
features = instance.samples[0].after_mip
|
assert sample.after_mip is not None
|
||||||
|
features = sample.after_mip
|
||||||
assert features is not None
|
assert features is not None
|
||||||
assert features.variables is not None
|
assert features.variables is not None
|
||||||
assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
|
assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
|
||||||
@@ -66,9 +67,10 @@ def test_subtour() -> None:
|
|||||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||||
solver = LearningSolver()
|
solver = LearningSolver()
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
assert len(instance.samples) == 1
|
assert len(instance.get_samples()) == 1
|
||||||
assert instance.samples[0].after_mip is not None
|
sample = instance.get_samples()[0]
|
||||||
features = instance.samples[0].after_mip
|
assert sample.after_mip is not None
|
||||||
|
features = sample.after_mip
|
||||||
assert features.extra is not None
|
assert features.extra is not None
|
||||||
assert "lazy_enforced" in features.extra
|
assert "lazy_enforced" in features.extra
|
||||||
lazy_enforced = features.extra["lazy_enforced"]
|
lazy_enforced = features.extra["lazy_enforced"]
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ def test_learning_solver(
|
|||||||
)
|
)
|
||||||
|
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
assert len(instance.samples) > 0
|
assert len(instance.get_samples()) > 0
|
||||||
sample = instance.samples[0]
|
sample = instance.get_samples()[0]
|
||||||
|
|
||||||
after_mip = sample.after_mip
|
after_mip = sample.after_mip
|
||||||
assert after_mip is not None
|
assert after_mip is not None
|
||||||
@@ -90,7 +90,7 @@ def test_parallel_solve(
|
|||||||
results = solver.parallel_solve(instances, n_jobs=3)
|
results = solver.parallel_solve(instances, n_jobs=3)
|
||||||
assert len(results) == 10
|
assert len(results) == 10
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
assert len(instance.samples) == 1
|
assert len(instance.get_samples()) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_solve_fit_from_disk(
|
def test_solve_fit_from_disk(
|
||||||
@@ -109,13 +109,13 @@ def test_solve_fit_from_disk(
|
|||||||
solver = LearningSolver(solver=internal_solver)
|
solver = LearningSolver(solver=internal_solver)
|
||||||
solver.solve(instances[0])
|
solver.solve(instances[0])
|
||||||
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
|
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
|
||||||
assert len(instance_loaded.samples) > 0
|
assert len(instance_loaded.get_samples()) > 0
|
||||||
|
|
||||||
# Test: parallel_solve
|
# Test: parallel_solve
|
||||||
solver.parallel_solve(instances)
|
solver.parallel_solve(instances)
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename)
|
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename)
|
||||||
assert len(instance_loaded.samples) > 0
|
assert len(instance_loaded.get_samples()) > 0
|
||||||
|
|
||||||
# Delete temporary files
|
# Delete temporary files
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
|
|||||||
Reference in New Issue
Block a user