Replace instance.samples by instance.get/push_sample

master
Alinson S. Xavier 4 years ago
parent a5092cc2b9
commit 80281df8d8
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -187,13 +187,14 @@ class Component:
instances: List[Instance], instances: List[Instance],
n_jobs: int = 1, n_jobs: int = 1,
) -> None: ) -> None:
# Part I: Pre-fit # Part I: Pre-fit
def _pre_sample_xy(instance: Instance) -> Dict: def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {} pre_instance: Dict = {}
for (cidx, comp) in enumerate(components): for (cidx, comp) in enumerate(components):
pre_instance[cidx] = [] pre_instance[cidx] = []
instance.load() instance.load()
for sample in instance.samples: for sample in instance.get_samples():
for (cidx, comp) in enumerate(components): for (cidx, comp) in enumerate(components):
pre_instance[cidx].append(comp.pre_sample_xy(instance, sample)) pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
instance.free() instance.free()
@ -219,7 +220,7 @@ class Component:
x_instance[cidx] = {} x_instance[cidx] = {}
y_instance[cidx] = {} y_instance[cidx] = {}
instance.load() instance.load()
for sample in instance.samples: for sample in instance.get_samples():
for (cidx, comp) in enumerate(components): for (cidx, comp) in enumerate(components):
x = x_instance[cidx] x = x_instance[cidx]
y = y_instance[cidx] y = y_instance[cidx]

@ -28,7 +28,7 @@ class Instance(ABC):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.samples: List[Sample] = [] self._samples: List[Sample] = []
@abstractmethod @abstractmethod
def to_model(self) -> Any: def to_model(self) -> Any:
@ -189,3 +189,9 @@ class Instance(ABC):
Save any pending changes made to the instance to the underlying data store. Save any pending changes made to the instance to the underlying data store.
""" """
pass pass
def get_samples(self) -> List[Sample]:
return self._samples
def push_sample(self, sample: Sample) -> None:
self._samples.append(sample)

@ -10,6 +10,7 @@ from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
from overrides import overrides from overrides import overrides
from miplearn.features import Sample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
if TYPE_CHECKING: if TYPE_CHECKING:
@ -120,18 +121,26 @@ class PickleGzInstance(Instance):
obj = read_pickle_gz(self.filename) obj = read_pickle_gz(self.filename)
assert isinstance(obj, Instance) assert isinstance(obj, Instance)
self.instance = obj self.instance = obj
self.samples = self.instance.samples
@overrides @overrides
def free(self) -> None: def free(self) -> None:
self.instance = None # type: ignore self.instance = None # type: ignore
self.samples = None # type: ignore
gc.collect() gc.collect()
@overrides @overrides
def flush(self) -> None: def flush(self) -> None:
write_pickle_gz(self.instance, self.filename) write_pickle_gz(self.instance, self.filename)
@overrides
def get_samples(self) -> List[Sample]:
assert self.instance is not None
return self.instance.get_samples()
@overrides
def push_sample(self, sample: Sample) -> None:
assert self.instance is not None
self.instance.push_sample(sample)
def write_pickle_gz(obj: Any, filename: str) -> None: def write_pickle_gz(obj: Any, filename: str) -> None:
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)

@ -150,7 +150,7 @@ class LearningSolver:
# Initialize training sample # Initialize training sample
# ------------------------------------------------------- # -------------------------------------------------------
sample = Sample() sample = Sample()
instance.samples.append(sample) instance.push_sample(sample)
# Initialize stats # Initialize stats
# ------------------------------------------------------- # -------------------------------------------------------

@ -25,7 +25,7 @@ E = 0.1
@pytest.fixture @pytest.fixture
def training_instances() -> List[Instance]: def training_instances() -> List[Instance]:
instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)] instances = [cast(Instance, Mock(spec=Instance)) for _ in range(2)]
instances[0].samples = [ samples_0 = [
Sample( Sample(
after_load=Features(instance=InstanceFeatures()), after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}), after_mip=Features(extra={"lazy_enforced": {"c1", "c2"}}),
@ -35,12 +35,9 @@ def training_instances() -> List[Instance]:
after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}), after_mip=Features(extra={"lazy_enforced": {"c2", "c3"}}),
), ),
] ]
instances[0].samples[0].after_load.instance.to_list = Mock( # type: ignore samples_0[0].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
return_value=[5.0] samples_0[1].after_load.instance.to_list = Mock(return_value=[5.0]) # type: ignore
) instances[0].get_samples = Mock(return_value=samples_0) # type: ignore
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].get_constraint_categories = Mock( # type: ignore instances[0].get_constraint_categories = Mock( # type: ignore
return_value={ return_value={
"c1": "type-a", "c1": "type-a",
@ -57,15 +54,14 @@ def training_instances() -> List[Instance]:
"c4": [3.0, 4.0], "c4": [3.0, 4.0],
} }
) )
instances[1].samples = [ samples_1 = [
Sample( Sample(
after_load=Features(instance=InstanceFeatures()), after_load=Features(instance=InstanceFeatures()),
after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}), after_mip=Features(extra={"lazy_enforced": {"c3", "c4"}}),
) )
] ]
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore samples_1[0].after_load.instance.to_list = Mock(return_value=[8.0]) # type: ignore
return_value=[8.0] instances[1].get_samples = Mock(return_value=samples_1) # type: ignore
)
instances[1].get_constraint_categories = Mock( # type: ignore instances[1].get_constraint_categories = Mock( # type: ignore
return_value={ return_value={
"c1": None, "c1": None,
@ -97,7 +93,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
} }
x_actual, y_actual = comp.sample_xy( x_actual, y_actual = comp.sample_xy(
training_instances[0], training_instances[0],
training_instances[0].samples[0], training_instances[0].get_samples()[0],
) )
assert_equals(x_actual, x_expected) assert_equals(x_actual, x_expected)
assert_equals(y_actual, y_expected) assert_equals(y_actual, y_expected)
@ -184,12 +180,12 @@ def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
) )
pred = comp.sample_predict( pred = comp.sample_predict(
training_instances[0], training_instances[0],
training_instances[0].samples[0], training_instances[0].get_samples()[0],
) )
assert pred == ["c1", "c4"] assert pred == ["c1", "c4"]
ev = comp.sample_evaluate( ev = comp.sample_evaluate(
training_instances[0], training_instances[0],
training_instances[0].samples[0], training_instances[0].get_samples()[0],
) )
assert ev == { assert ev == {
"type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1), "type-a": classifier_evaluation_dict(tp=1, fp=0, tn=0, fn=1),

@ -80,7 +80,7 @@ def test_usage(
solver: LearningSolver, solver: LearningSolver,
) -> None: ) -> None:
stats_before = solver.solve(stab_instance) stats_before = solver.solve(stab_instance)
sample = stab_instance.samples[0] sample = stab_instance.get_samples()[0]
assert sample.after_mip is not None assert sample.after_mip is not None
assert sample.after_mip.extra is not None assert sample.after_mip.extra is not None
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0 assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0

@ -70,7 +70,7 @@ def sample() -> Sample:
@pytest.fixture @pytest.fixture
def instance(sample: Sample) -> Instance: def instance(sample: Sample) -> Instance:
instance = Mock(spec=Instance) instance = Mock(spec=Instance)
instance.samples = [sample] instance.get_samples = Mock(return_value=[sample]) # type: ignore
instance.has_static_lazy_constraints = Mock(return_value=True) instance.has_static_lazy_constraints = Mock(return_value=True)
return instance return instance
@ -111,7 +111,7 @@ def test_usage_with_solver(instance: Instance) -> None:
) )
stats: LearningSolveStats = {} stats: LearningSolveStats = {}
sample = instance.samples[0] sample = instance.get_samples()[0]
assert sample.after_load is not None assert sample.after_load is not None
assert sample.after_mip is not None assert sample.after_mip is not None
assert sample.after_mip.extra is not None assert sample.after_mip.extra is not None

@ -39,9 +39,10 @@ def test_instance() -> None:
instance = TravelingSalesmanInstance(n_cities, distances) instance = TravelingSalesmanInstance(n_cities, distances)
solver = LearningSolver() solver = LearningSolver()
solver.solve(instance) solver.solve(instance)
assert len(instance.samples) == 1 assert len(instance.get_samples()) == 1
assert instance.samples[0].after_mip is not None sample = instance.get_samples()[0]
features = instance.samples[0].after_mip assert sample.after_mip is not None
features = sample.after_mip
assert features is not None assert features is not None
assert features.variables is not None assert features.variables is not None
assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0] assert features.variables.values == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
@ -66,9 +67,10 @@ def test_subtour() -> None:
instance = TravelingSalesmanInstance(n_cities, distances) instance = TravelingSalesmanInstance(n_cities, distances)
solver = LearningSolver() solver = LearningSolver()
solver.solve(instance) solver.solve(instance)
assert len(instance.samples) == 1 assert len(instance.get_samples()) == 1
assert instance.samples[0].after_mip is not None sample = instance.get_samples()[0]
features = instance.samples[0].after_mip assert sample.after_mip is not None
features = sample.after_mip
assert features.extra is not None assert features.extra is not None
assert "lazy_enforced" in features.extra assert "lazy_enforced" in features.extra
lazy_enforced = features.extra["lazy_enforced"] lazy_enforced = features.extra["lazy_enforced"]

@ -35,8 +35,8 @@ def test_learning_solver(
) )
solver.solve(instance) solver.solve(instance)
assert len(instance.samples) > 0 assert len(instance.get_samples()) > 0
sample = instance.samples[0] sample = instance.get_samples()[0]
after_mip = sample.after_mip after_mip = sample.after_mip
assert after_mip is not None assert after_mip is not None
@ -90,7 +90,7 @@ def test_parallel_solve(
results = solver.parallel_solve(instances, n_jobs=3) results = solver.parallel_solve(instances, n_jobs=3)
assert len(results) == 10 assert len(results) == 10
for instance in instances: for instance in instances:
assert len(instance.samples) == 1 assert len(instance.get_samples()) == 1
def test_solve_fit_from_disk( def test_solve_fit_from_disk(
@ -109,13 +109,13 @@ def test_solve_fit_from_disk(
solver = LearningSolver(solver=internal_solver) solver = LearningSolver(solver=internal_solver)
solver.solve(instances[0]) solver.solve(instances[0])
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename) instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
assert len(instance_loaded.samples) > 0 assert len(instance_loaded.get_samples()) > 0
# Test: parallel_solve # Test: parallel_solve
solver.parallel_solve(instances) solver.parallel_solve(instances)
for instance in instances: for instance in instances:
instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename) instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename)
assert len(instance_loaded.samples) > 0 assert len(instance_loaded.get_samples()) > 0
# Delete temporary files # Delete temporary files
for instance in instances: for instance in instances:

Loading…
Cancel
Save