Train without loading all instances to memory

pull/3/head
Alinson S. Xavier 5 years ago
parent f03cc15b75
commit 388b10c63c

@ -37,7 +37,8 @@ class BenchmarkRunner:
for (solver_name, solver) in self.solvers.items():
results = solver.parallel_solve(trials,
n_jobs=n_jobs,
label="Solve (%s)" % solver_name)
label="Solve (%s)" % solver_name,
output=None)
for i in range(len(trials)):
idx = (i % len(instances)) + index_offset
self._push_result(results[i],

@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component):
for category in tqdm(features.keys(),
desc="Fit (primal)",
disable=not sys.stdout.isatty(),
):
x_train = features[category]
for label in [0, 1]:
@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component):
"Fix one": {}}
for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)",
disable=not sys.stdout.isatty(),
):
instance = instances[instance_idx]
solution_actual = instance.solution

@ -8,10 +8,11 @@ from copy import deepcopy
import numpy as np
from miplearn.components import classifier_evaluation_dict
from tqdm import tqdm
from tqdm.auto import tqdm
from miplearn import Component
from miplearn.classifiers.counting import CountingClassifier
from miplearn.extractors import InstanceIterator
logger = logging.getLogger(__name__)
@ -60,16 +61,12 @@ class RelaxationComponent(Component):
instance.slacks = solver.internal_solver.get_constraint_slacks()
def fit(self, training_instances):
training_instances = [instance
for instance in training_instances
if hasattr(instance, "slacks")]
logger.debug("Extracting x and y...")
x = self.x(training_instances)
y = self.y(training_instances)
logger.debug("Fitting...")
for category in tqdm(x.keys(),
desc="Fit (relaxation)",
disable=not sys.stdout.isatty()):
desc="Fit (relaxation)"):
if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category].fit(x[category], y[category])
@ -80,7 +77,9 @@ class RelaxationComponent(Component):
return_constraints=False):
x = {}
constraints = {}
for instance in instances:
for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:x)",
disable=len(instances) < 5):
if constraint_ids is not None:
cids = constraint_ids
else:
@ -101,7 +100,9 @@ class RelaxationComponent(Component):
def y(self, instances):
y = {}
for instance in instances:
for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:y)",
disable=len(instances) < 5):
for (cid, slack) in instance.slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
@ -120,7 +121,7 @@ class RelaxationComponent(Component):
if category not in self.classifiers:
continue
y[category] = []
#x_cat = np.array(x_cat)
x_cat = np.array(x_cat)
proba = self.classifiers[category].predict_proba(x_cat)
for i in range(len(proba)):
if proba[i][1] >= self.threshold:

@ -3,14 +3,41 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from abc import ABC, abstractmethod
import pickle
import gzip
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
class InstanceIterator:
def __init__(self, instances):
self.instances = instances
self.current = 0
def __iter__(self):
return self
def __next__(self):
if self.current >= len(self.instances):
raise StopIteration
result = self.instances[self.current]
self.current += 1
if isinstance(result, str):
logger.info("Read: %s" % result)
if result.endswith(".gz"):
with gzip.GzipFile(result, "rb") as file:
result = pickle.load(file)
else:
with open(result, "rb") as file:
result = pickle.load(file)
return result
class Extractor(ABC):
@abstractmethod
def extract(self, instances,):
@ -34,7 +61,7 @@ class Extractor(ABC):
class VariableFeaturesExtractor(Extractor):
def extract(self, instances):
result = {}
for instance in tqdm(instances,
for instance in tqdm(InstanceIterator(instances),
desc="Extract (vars)",
disable=len(instances) < 5):
instance_features = instance.get_instance_features()
@ -59,7 +86,7 @@ class SolutionExtractor(Extractor):
def extract(self, instances):
result = {}
for instance in tqdm(instances,
for instance in tqdm(InstanceIterator(instances),
desc="Extract (solution)",
disable=len(instances) < 5):
var_split = self.split_variables(instance)
@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor):
instance.get_instance_features(),
instance.lp_value,
])
for instance in instances
for instance in InstanceIterator(instances)
])
@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor):
def extract(self, instances):
if self.kind == "lower bound":
return np.array([[instance.lower_bound] for instance in instances])
return np.array([[instance.lower_bound]
for instance in InstanceIterator(instances)])
if self.kind == "upper bound":
return np.array([[instance.upper_bound] for instance in instances])
return np.array([[instance.upper_bound]
for instance in InstanceIterator(instances)])
if self.kind == "lp":
return np.array([[instance.lp_value] for instance in instances])
return np.array([[instance.lp_value]
for instance in InstanceIterator(instances)])

@ -6,6 +6,7 @@ import logging
import pickle
import os
import tempfile
import gzip
from copy import deepcopy
from typing import Optional, List
@ -198,9 +199,16 @@ class LearningSolver:
"""
filename = None
fileformat = None
if isinstance(instance, str):
filename = instance
logger.info("Reading: %s" % filename)
if filename.endswith(".gz"):
fileformat = "pickle-gz"
with gzip.GzipFile(filename, "rb") as file:
instance = pickle.load(file)
else:
fileformat = "pickle"
with open(filename, "rb") as file:
instance = pickle.load(file)
@ -260,9 +268,12 @@ class LearningSolver:
if len(output) == 0:
output_filename = filename
logger.info("Writing: %s" % output_filename)
with tempfile.NamedTemporaryFile(delete=False) as tmp:
pickle.dump(instance, tmp)
os.replace(tmp.name, output_filename)
if fileformat == "pickle":
with open(output_filename, "wb") as file:
pickle.dump(instance, file)
else:
with gzip.GzipFile(output_filename, "wb") as file:
pickle.dump(instance, file)
return results

Loading…
Cancel
Save