Train without loading all instances to memory

pull/3/head
Alinson S. Xavier 5 years ago
parent f03cc15b75
commit 388b10c63c

@ -37,7 +37,8 @@ class BenchmarkRunner:
for (solver_name, solver) in self.solvers.items(): for (solver_name, solver) in self.solvers.items():
results = solver.parallel_solve(trials, results = solver.parallel_solve(trials,
n_jobs=n_jobs, n_jobs=n_jobs,
label="Solve (%s)" % solver_name) label="Solve (%s)" % solver_name,
output=None)
for i in range(len(trials)): for i in range(len(trials)):
idx = (i % len(instances)) + index_offset idx = (i % len(instances)) + index_offset
self._push_result(results[i], self._push_result(results[i],

@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component):
for category in tqdm(features.keys(), for category in tqdm(features.keys(),
desc="Fit (primal)", desc="Fit (primal)",
disable=not sys.stdout.isatty(),
): ):
x_train = features[category] x_train = features[category]
for label in [0, 1]: for label in [0, 1]:
@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component):
"Fix one": {}} "Fix one": {}}
for instance_idx in tqdm(range(len(instances)), for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)", desc="Evaluate (primal)",
disable=not sys.stdout.isatty(),
): ):
instance = instances[instance_idx] instance = instances[instance_idx]
solution_actual = instance.solution solution_actual = instance.solution

@ -8,10 +8,11 @@ from copy import deepcopy
import numpy as np import numpy as np
from miplearn.components import classifier_evaluation_dict from miplearn.components import classifier_evaluation_dict
from tqdm import tqdm from tqdm.auto import tqdm
from miplearn import Component from miplearn import Component
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
from miplearn.extractors import InstanceIterator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,16 +61,12 @@ class RelaxationComponent(Component):
instance.slacks = solver.internal_solver.get_constraint_slacks() instance.slacks = solver.internal_solver.get_constraint_slacks()
def fit(self, training_instances): def fit(self, training_instances):
training_instances = [instance
for instance in training_instances
if hasattr(instance, "slacks")]
logger.debug("Extracting x and y...") logger.debug("Extracting x and y...")
x = self.x(training_instances) x = self.x(training_instances)
y = self.y(training_instances) y = self.y(training_instances)
logger.debug("Fitting...") logger.debug("Fitting...")
for category in tqdm(x.keys(), for category in tqdm(x.keys(),
desc="Fit (relaxation)", desc="Fit (relaxation)"):
disable=not sys.stdout.isatty()):
if category not in self.classifiers: if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype) self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category].fit(x[category], y[category]) self.classifiers[category].fit(x[category], y[category])
@ -80,7 +77,9 @@ class RelaxationComponent(Component):
return_constraints=False): return_constraints=False):
x = {} x = {}
constraints = {} constraints = {}
for instance in instances: for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:x)",
disable=len(instances) < 5):
if constraint_ids is not None: if constraint_ids is not None:
cids = constraint_ids cids = constraint_ids
else: else:
@ -101,7 +100,9 @@ class RelaxationComponent(Component):
def y(self, instances): def y(self, instances):
y = {} y = {}
for instance in instances: for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:y)",
disable=len(instances) < 5):
for (cid, slack) in instance.slacks.items(): for (cid, slack) in instance.slacks.items():
category = instance.get_constraint_category(cid) category = instance.get_constraint_category(cid)
if category is None: if category is None:
@ -120,7 +121,7 @@ class RelaxationComponent(Component):
if category not in self.classifiers: if category not in self.classifiers:
continue continue
y[category] = [] y[category] = []
#x_cat = np.array(x_cat) x_cat = np.array(x_cat)
proba = self.classifiers[category].predict_proba(x_cat) proba = self.classifiers[category].predict_proba(x_cat)
for i in range(len(proba)): for i in range(len(proba)):
if proba[i][1] >= self.threshold: if proba[i][1] >= self.threshold:

@ -3,14 +3,41 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from abc import ABC, abstractmethod import pickle
import gzip
import numpy as np import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InstanceIterator:
def __init__(self, instances):
self.instances = instances
self.current = 0
def __iter__(self):
return self
def __next__(self):
if self.current >= len(self.instances):
raise StopIteration
result = self.instances[self.current]
self.current += 1
if isinstance(result, str):
logger.info("Read: %s" % result)
if result.endswith(".gz"):
with gzip.GzipFile(result, "rb") as file:
result = pickle.load(file)
else:
with open(result, "rb") as file:
result = pickle.load(file)
return result
class Extractor(ABC): class Extractor(ABC):
@abstractmethod @abstractmethod
def extract(self, instances,): def extract(self, instances,):
@ -34,7 +61,7 @@ class Extractor(ABC):
class VariableFeaturesExtractor(Extractor): class VariableFeaturesExtractor(Extractor):
def extract(self, instances): def extract(self, instances):
result = {} result = {}
for instance in tqdm(instances, for instance in tqdm(InstanceIterator(instances),
desc="Extract (vars)", desc="Extract (vars)",
disable=len(instances) < 5): disable=len(instances) < 5):
instance_features = instance.get_instance_features() instance_features = instance.get_instance_features()
@ -59,7 +86,7 @@ class SolutionExtractor(Extractor):
def extract(self, instances): def extract(self, instances):
result = {} result = {}
for instance in tqdm(instances, for instance in tqdm(InstanceIterator(instances),
desc="Extract (solution)", desc="Extract (solution)",
disable=len(instances) < 5): disable=len(instances) < 5):
var_split = self.split_variables(instance) var_split = self.split_variables(instance)
@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor):
instance.get_instance_features(), instance.get_instance_features(),
instance.lp_value, instance.lp_value,
]) ])
for instance in instances for instance in InstanceIterator(instances)
]) ])
@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor):
def extract(self, instances): def extract(self, instances):
if self.kind == "lower bound": if self.kind == "lower bound":
return np.array([[instance.lower_bound] for instance in instances]) return np.array([[instance.lower_bound]
for instance in InstanceIterator(instances)])
if self.kind == "upper bound": if self.kind == "upper bound":
return np.array([[instance.upper_bound] for instance in instances]) return np.array([[instance.upper_bound]
for instance in InstanceIterator(instances)])
if self.kind == "lp": if self.kind == "lp":
return np.array([[instance.lp_value] for instance in instances]) return np.array([[instance.lp_value]
for instance in InstanceIterator(instances)])

@ -6,6 +6,7 @@ import logging
import pickle import pickle
import os import os
import tempfile import tempfile
import gzip
from copy import deepcopy from copy import deepcopy
from typing import Optional, List from typing import Optional, List
@ -198,11 +199,18 @@ class LearningSolver:
""" """
filename = None filename = None
fileformat = None
if isinstance(instance, str): if isinstance(instance, str):
filename = instance filename = instance
logger.info("Reading: %s" % filename) logger.info("Reading: %s" % filename)
with open(filename, "rb") as file: if filename.endswith(".gz"):
instance = pickle.load(file) fileformat = "pickle-gz"
with gzip.GzipFile(filename, "rb") as file:
instance = pickle.load(file)
else:
fileformat = "pickle"
with open(filename, "rb") as file:
instance = pickle.load(file)
if model is None: if model is None:
model = instance.to_model() model = instance.to_model()
@ -260,9 +268,12 @@ class LearningSolver:
if len(output) == 0: if len(output) == 0:
output_filename = filename output_filename = filename
logger.info("Writing: %s" % output_filename) logger.info("Writing: %s" % output_filename)
with tempfile.NamedTemporaryFile(delete=False) as tmp: if fileformat == "pickle":
pickle.dump(instance, tmp) with open(output_filename, "wb") as file:
os.replace(tmp.name, output_filename) pickle.dump(instance, file)
else:
with gzip.GzipFile(output_filename, "wb") as file:
pickle.dump(instance, file)
return results return results

Loading…
Cancel
Save