PickleGzInstance: Replace implicit load by load/free methods

master
Alinson S. Xavier 5 years ago
parent f495297168
commit 856b595d5e

@ -125,7 +125,7 @@ class Component:
x_combined: Dict = {} x_combined: Dict = {}
y_combined: Dict = {} y_combined: Dict = {}
for instance in instances: for instance in instances:
assert isinstance(instance, Instance) instance.load()
for sample in instance.training_data: for sample in instance.training_data:
xy = self.sample_xy(instance, sample) xy = self.sample_xy(instance, sample)
if xy is None: if xy is None:
@ -137,6 +137,7 @@ class Component:
y_combined[cat] = [] y_combined[cat] = []
x_combined[cat] += x_sample[cat] x_combined[cat] += x_sample[cat]
y_combined[cat] += y_sample[cat] y_combined[cat] += y_sample[cat]
instance.free()
return x_combined, y_combined return x_combined, y_combined
def fit( def fit(
@ -209,8 +210,10 @@ class Component:
def evaluate(self, instances: List[Instance]) -> List: def evaluate(self, instances: List[Instance]) -> List:
ev = [] ev = []
for instance in instances: for instance in instances:
instance.load()
for sample in instance.training_data: for sample in instance.training_data:
ev += [self.sample_evaluate(instance, sample)] ev += [self.sample_evaluate(instance, sample)]
instance.free()
return ev return ev
def sample_evaluate( def sample_evaluate(

@ -104,10 +104,12 @@ class DynamicConstraintsComponent(Component):
def fit(self, training_instances: List["Instance"]) -> None: def fit(self, training_instances: List["Instance"]) -> None:
collected_cids = set() collected_cids = set()
for instance in training_instances: for instance in training_instances:
instance.load()
for sample in instance.training_data: for sample in instance.training_data:
if getattr(sample, self.attr) is None: if getattr(sample, self.attr) is None:
continue continue
collected_cids |= getattr(sample, self.attr) collected_cids |= getattr(sample, self.attr)
instance.free()
self.known_cids.clear() self.known_cids.clear()
self.known_cids.extend(sorted(collected_cids)) self.known_cids.extend(sorted(collected_cids))
super().fit(training_instances) super().fit(training_instances)

@ -158,6 +158,12 @@ class Instance(ABC):
def build_user_cut(self, model: Any, violation: Hashable) -> Any: def build_user_cut(self, model: Any, violation: Hashable) -> Any:
return None return None
def load(self) -> None:
pass
def free(self) -> None:
pass
def flush(self) -> None: def flush(self) -> None:
""" """
Save any pending changes made to the instance to the underlying data store. Save any pending changes made to the instance to the underlying data store.

@ -5,23 +5,13 @@
import gzip import gzip
import os import os
import pickle import pickle
import gc
from typing import Optional, Any, List, Hashable, cast, IO, Callable from typing import Optional, Any, List, Hashable, cast, IO, Callable
from miplearn.instance.base import logger, Instance from miplearn.instance.base import logger, Instance
from miplearn.types import VarIndex from miplearn.types import VarIndex
def lazy_load(func: Callable) -> Callable:
def inner(self: Any, *args: Any) -> Any:
if self.instance is None:
self.instance = self._load()
self.features = self.instance.features
self.training_data = self.instance.training_data
return func(self, *args)
return inner
class PickleGzInstance(Instance): class PickleGzInstance(Instance):
""" """
An instance backed by a gzipped pickle file. An instance backed by a gzipped pickle file.
@ -41,22 +31,18 @@ class PickleGzInstance(Instance):
self.instance: Optional[Instance] = None self.instance: Optional[Instance] = None
self.filename: str = filename self.filename: str = filename
@lazy_load
def to_model(self) -> Any: def to_model(self) -> Any:
assert self.instance is not None assert self.instance is not None
return self.instance.to_model() return self.instance.to_model()
@lazy_load
def get_instance_features(self) -> List[float]: def get_instance_features(self) -> List[float]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_instance_features() return self.instance.get_instance_features()
@lazy_load
def get_variable_features(self, var_name: str, index: VarIndex) -> List[float]: def get_variable_features(self, var_name: str, index: VarIndex) -> List[float]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_variable_features(var_name, index) return self.instance.get_variable_features(var_name, index)
@lazy_load
def get_variable_category( def get_variable_category(
self, self,
var_name: str, var_name: str,
@ -65,55 +51,55 @@ class PickleGzInstance(Instance):
assert self.instance is not None assert self.instance is not None
return self.instance.get_variable_category(var_name, index) return self.instance.get_variable_category(var_name, index)
@lazy_load
def get_constraint_features(self, cid: str) -> Optional[List[float]]: def get_constraint_features(self, cid: str) -> Optional[List[float]]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_constraint_features(cid) return self.instance.get_constraint_features(cid)
@lazy_load
def get_constraint_category(self, cid: str) -> Optional[Hashable]: def get_constraint_category(self, cid: str) -> Optional[Hashable]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_constraint_category(cid) return self.instance.get_constraint_category(cid)
@lazy_load
def has_static_lazy_constraints(self) -> bool: def has_static_lazy_constraints(self) -> bool:
assert self.instance is not None assert self.instance is not None
return self.instance.has_static_lazy_constraints() return self.instance.has_static_lazy_constraints()
@lazy_load
def has_dynamic_lazy_constraints(self) -> bool: def has_dynamic_lazy_constraints(self) -> bool:
assert self.instance is not None assert self.instance is not None
return self.instance.has_dynamic_lazy_constraints() return self.instance.has_dynamic_lazy_constraints()
@lazy_load
def is_constraint_lazy(self, cid: str) -> bool: def is_constraint_lazy(self, cid: str) -> bool:
assert self.instance is not None assert self.instance is not None
return self.instance.is_constraint_lazy(cid) return self.instance.is_constraint_lazy(cid)
@lazy_load
def find_violated_lazy_constraints(self, model: Any) -> List[Hashable]: def find_violated_lazy_constraints(self, model: Any) -> List[Hashable]:
assert self.instance is not None assert self.instance is not None
return self.instance.find_violated_lazy_constraints(model) return self.instance.find_violated_lazy_constraints(model)
@lazy_load
def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any: def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any:
assert self.instance is not None assert self.instance is not None
return self.instance.build_lazy_constraint(model, violation) return self.instance.build_lazy_constraint(model, violation)
@lazy_load
def find_violated_user_cuts(self, model: Any) -> List[Hashable]: def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
assert self.instance is not None assert self.instance is not None
return self.instance.find_violated_user_cuts(model) return self.instance.find_violated_user_cuts(model)
@lazy_load
def build_user_cut(self, model: Any, violation: Hashable) -> Any: def build_user_cut(self, model: Any, violation: Hashable) -> Any:
assert self.instance is not None assert self.instance is not None
return self.instance.build_user_cut(model, violation) return self.instance.build_user_cut(model, violation)
def _load(self) -> Instance: def load(self) -> None:
if self.instance is None:
obj = read_pickle_gz(self.filename) obj = read_pickle_gz(self.filename)
assert isinstance(obj, Instance) assert isinstance(obj, Instance)
return obj self.instance = obj
self.features = self.instance.features
self.training_data = self.instance.training_data
def free(self) -> None:
self.instance = None # type: ignore
self.features = None # type: ignore
self.training_data = None # type: ignore
gc.collect()
def flush(self) -> None: def flush(self) -> None:
write_pickle_gz(self.instance, self.filename) write_pickle_gz(self.instance, self.filename)

@ -127,6 +127,7 @@ class LearningSolver:
# Generate model # Generate model
# ------------------------------------------------------- # -------------------------------------------------------
instance.load()
if model is None: if model is None:
with _RedirectOutput([]): with _RedirectOutput([]):
model = instance.to_model() model = instance.to_model()

@ -13,4 +13,5 @@ def test_usage() -> None:
file = tempfile.NamedTemporaryFile() file = tempfile.NamedTemporaryFile()
write_pickle_gz(original, file.name) write_pickle_gz(original, file.name)
pickled = PickleGzInstance(file.name) pickled = PickleGzInstance(file.name)
pickled.load()
assert pickled.to_model() is not None assert pickled.to_model() is not None

Loading…
Cancel
Save