mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Replace instance.samples by instance.get/push_sample
This commit is contained in:
@@ -187,13 +187,14 @@ class Component:
|
||||
instances: List[Instance],
|
||||
n_jobs: int = 1,
|
||||
) -> None:
|
||||
|
||||
# Part I: Pre-fit
|
||||
def _pre_sample_xy(instance: Instance) -> Dict:
|
||||
pre_instance: Dict = {}
|
||||
for (cidx, comp) in enumerate(components):
|
||||
pre_instance[cidx] = []
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
for sample in instance.get_samples():
|
||||
for (cidx, comp) in enumerate(components):
|
||||
pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
|
||||
instance.free()
|
||||
@@ -219,7 +220,7 @@ class Component:
|
||||
x_instance[cidx] = {}
|
||||
y_instance[cidx] = {}
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
for sample in instance.get_samples():
|
||||
for (cidx, comp) in enumerate(components):
|
||||
x = x_instance[cidx]
|
||||
y = y_instance[cidx]
|
||||
|
||||
@@ -28,7 +28,7 @@ class Instance(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.samples: List[Sample] = []
|
||||
self._samples: List[Sample] = []
|
||||
|
||||
@abstractmethod
|
||||
def to_model(self) -> Any:
|
||||
@@ -189,3 +189,9 @@ class Instance(ABC):
|
||||
Save any pending changes made to the instance to the underlying data store.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_samples(self) -> List[Sample]:
|
||||
return self._samples
|
||||
|
||||
def push_sample(self, sample: Sample) -> None:
|
||||
self._samples.append(sample)
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -120,18 +121,26 @@ class PickleGzInstance(Instance):
|
||||
obj = read_pickle_gz(self.filename)
|
||||
assert isinstance(obj, Instance)
|
||||
self.instance = obj
|
||||
self.samples = self.instance.samples
|
||||
|
||||
@overrides
|
||||
def free(self) -> None:
|
||||
self.instance = None # type: ignore
|
||||
self.samples = None # type: ignore
|
||||
gc.collect()
|
||||
|
||||
@overrides
|
||||
def flush(self) -> None:
|
||||
write_pickle_gz(self.instance, self.filename)
|
||||
|
||||
@overrides
|
||||
def get_samples(self) -> List[Sample]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_samples()
|
||||
|
||||
@overrides
|
||||
def push_sample(self, sample: Sample) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.push_sample(sample)
|
||||
|
||||
|
||||
def write_pickle_gz(obj: Any, filename: str) -> None:
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
@@ -150,7 +150,7 @@ class LearningSolver:
|
||||
# Initialize training sample
|
||||
# -------------------------------------------------------
|
||||
sample = Sample()
|
||||
instance.samples.append(sample)
|
||||
instance.push_sample(sample)
|
||||
|
||||
# Initialize stats
|
||||
# -------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user