Train without loading all instances to memory

This commit is contained in:
2020-12-04 09:37:41 -06:00
parent f03cc15b75
commit 388b10c63c
5 changed files with 66 additions and 25 deletions

View File

@@ -3,14 +3,41 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from abc import ABC, abstractmethod
import pickle
import gzip
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
class InstanceIterator:
def __init__(self, instances):
self.instances = instances
self.current = 0
def __iter__(self):
return self
def __next__(self):
if self.current >= len(self.instances):
raise StopIteration
result = self.instances[self.current]
self.current += 1
if isinstance(result, str):
logger.info("Read: %s" % result)
if result.endswith(".gz"):
with gzip.GzipFile(result, "rb") as file:
result = pickle.load(file)
else:
with open(result, "rb") as file:
result = pickle.load(file)
return result
class Extractor(ABC):
@abstractmethod
def extract(self, instances,):
@@ -34,7 +61,7 @@ class Extractor(ABC):
class VariableFeaturesExtractor(Extractor):
def extract(self, instances):
result = {}
for instance in tqdm(instances,
for instance in tqdm(InstanceIterator(instances),
desc="Extract (vars)",
disable=len(instances) < 5):
instance_features = instance.get_instance_features()
@@ -59,7 +86,7 @@ class SolutionExtractor(Extractor):
def extract(self, instances):
result = {}
for instance in tqdm(instances,
for instance in tqdm(InstanceIterator(instances),
desc="Extract (solution)",
disable=len(instances) < 5):
var_split = self.split_variables(instance)
@@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor):
instance.get_instance_features(),
instance.lp_value,
])
for instance in instances
for instance in InstanceIterator(instances)
])
@@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor):
def extract(self, instances):
if self.kind == "lower bound":
return np.array([[instance.lower_bound] for instance in instances])
return np.array([[instance.lower_bound]
for instance in InstanceIterator(instances)])
if self.kind == "upper bound":
return np.array([[instance.upper_bound] for instance in instances])
return np.array([[instance.upper_bound]
for instance in InstanceIterator(instances)])
if self.kind == "lp":
return np.array([[instance.lp_value] for instance in instances])
return np.array([[instance.lp_value]
for instance in InstanceIterator(instances)])