Replace instance.samples by instance.get/push_sample

This commit is contained in:
2021-06-29 16:49:24 -05:00
parent a5092cc2b9
commit 80281df8d8
9 changed files with 48 additions and 34 deletions

View File

@@ -28,7 +28,7 @@ class Instance(ABC):
"""
def __init__(self) -> None:
self.samples: List[Sample] = []
self._samples: List[Sample] = []
@abstractmethod
def to_model(self) -> Any:
@@ -189,3 +189,9 @@ class Instance(ABC):
Save any pending changes made to the instance to the underlying data store.
"""
pass
def get_samples(self) -> List[Sample]:
return self._samples
def push_sample(self, sample: Sample) -> None:
self._samples.append(sample)

View File

@@ -10,6 +10,7 @@ from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
from overrides import overrides
from miplearn.features import Sample
from miplearn.instance.base import Instance
if TYPE_CHECKING:
@@ -120,18 +121,26 @@ class PickleGzInstance(Instance):
obj = read_pickle_gz(self.filename)
assert isinstance(obj, Instance)
self.instance = obj
self.samples = self.instance.samples
@overrides
def free(self) -> None:
self.instance = None # type: ignore
self.samples = None # type: ignore
gc.collect()
@overrides
def flush(self) -> None:
write_pickle_gz(self.instance, self.filename)
@overrides
def get_samples(self) -> List[Sample]:
assert self.instance is not None
return self.instance.get_samples()
@overrides
def push_sample(self, sample: Sample) -> None:
assert self.instance is not None
self.instance.push_sample(sample)
def write_pickle_gz(obj: Any, filename: str) -> None:
os.makedirs(os.path.dirname(filename), exist_ok=True)