mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Train without loading all instances into memory
This commit is contained in:
@@ -37,7 +37,8 @@ class BenchmarkRunner:
|
||||
for (solver_name, solver) in self.solvers.items():
|
||||
results = solver.parallel_solve(trials,
|
||||
n_jobs=n_jobs,
|
||||
label="Solve (%s)" % solver_name)
|
||||
label="Solve (%s)" % solver_name,
|
||||
output=None)
|
||||
for i in range(len(trials)):
|
||||
idx = (i % len(instances)) + index_offset
|
||||
self._push_result(results[i],
|
||||
|
||||
@@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
for category in tqdm(features.keys(),
|
||||
desc="Fit (primal)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
x_train = features[category]
|
||||
for label in [0, 1]:
|
||||
@@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component):
|
||||
"Fix one": {}}
|
||||
for instance_idx in tqdm(range(len(instances)),
|
||||
desc="Evaluate (primal)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
instance = instances[instance_idx]
|
||||
solution_actual = instance.solution
|
||||
|
||||
@@ -8,10 +8,11 @@ from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from tqdm import tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn import Component
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.extractors import InstanceIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,16 +61,12 @@ class RelaxationComponent(Component):
|
||||
instance.slacks = solver.internal_solver.get_constraint_slacks()
|
||||
|
||||
def fit(self, training_instances):
|
||||
training_instances = [instance
|
||||
for instance in training_instances
|
||||
if hasattr(instance, "slacks")]
|
||||
logger.debug("Extracting x and y...")
|
||||
x = self.x(training_instances)
|
||||
y = self.y(training_instances)
|
||||
logger.debug("Fitting...")
|
||||
for category in tqdm(x.keys(),
|
||||
desc="Fit (relaxation)",
|
||||
disable=not sys.stdout.isatty()):
|
||||
desc="Fit (relaxation)"):
|
||||
if category not in self.classifiers:
|
||||
self.classifiers[category] = deepcopy(self.classifier_prototype)
|
||||
self.classifiers[category].fit(x[category], y[category])
|
||||
@@ -80,7 +77,9 @@ class RelaxationComponent(Component):
|
||||
return_constraints=False):
|
||||
x = {}
|
||||
constraints = {}
|
||||
for instance in instances:
|
||||
for instance in tqdm(InstanceIterator(instances),
|
||||
desc="Extract (relaxation:x)",
|
||||
disable=len(instances) < 5):
|
||||
if constraint_ids is not None:
|
||||
cids = constraint_ids
|
||||
else:
|
||||
@@ -101,7 +100,9 @@ class RelaxationComponent(Component):
|
||||
|
||||
def y(self, instances):
|
||||
y = {}
|
||||
for instance in instances:
|
||||
for instance in tqdm(InstanceIterator(instances),
|
||||
desc="Extract (relaxation:y)",
|
||||
disable=len(instances) < 5):
|
||||
for (cid, slack) in instance.slacks.items():
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
@@ -120,7 +121,7 @@ class RelaxationComponent(Component):
|
||||
if category not in self.classifiers:
|
||||
continue
|
||||
y[category] = []
|
||||
#x_cat = np.array(x_cat)
|
||||
x_cat = np.array(x_cat)
|
||||
proba = self.classifiers[category].predict_proba(x_cat)
|
||||
for i in range(len(proba)):
|
||||
if proba[i][1] >= self.threshold:
|
||||
|
||||
@@ -3,14 +3,41 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
import gzip
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InstanceIterator:
    """Single-pass iterator over instances that loads file-backed ones lazily.

    Each element of ``instances`` is either an in-memory object, which is
    returned unchanged, or a ``str``, which is treated as the path of a
    pickle file (gzip-compressed when the path ends in ``.gz``) and is
    unpickled only when the iteration reaches it. Only the element currently
    being visited therefore has to be held in memory by this class.

    The object is its own iterator and keeps its position in ``current``;
    once exhausted it stays exhausted. Create a new ``InstanceIterator`` to
    traverse the collection again.

    Parameters
    ----------
    instances
        Indexable, sized collection (it must support ``len()`` and integer
        indexing) of instance objects and/or pickle file paths.
    """

    def __init__(self, instances):
        self.instances = instances
        self.current = 0

    def __iter__(self):
        return self

    def __len__(self):
        """Return the total number of instances in the wrapped collection.

        This is the size of the collection, not the number of elements still
        to be visited. It lets callers such as ``tqdm`` show a total and a
        percentage when wrapping this iterator.
        """
        return len(self.instances)

    def __next__(self):
        """Return the next instance, unpickling it first if it is a path.

        Raises
        ------
        StopIteration
            When every element of the collection has been visited.
        """
        if self.current >= len(self.instances):
            raise StopIteration
        result = self.instances[self.current]
        self.current += 1
        if isinstance(result, str):
            result = self._load(result)
        return result

    @staticmethod
    def _load(filename):
        """Unpickle and return the object stored in ``filename``."""
        logger.info("Read: %s", filename)
        # Both callables accept (filename, mode), so the gzip and the plain
        # case can share a single code path.
        opener = gzip.GzipFile if filename.endswith(".gz") else open
        # NOTE: unpickling executes arbitrary code embedded in the file; only
        # pass paths to files that come from a trusted source.
        with opener(filename, "rb") as file:
            return pickle.load(file)
|
||||
|
||||
|
||||
class Extractor(ABC):
|
||||
@abstractmethod
|
||||
def extract(self, instances,):
|
||||
@@ -34,7 +61,7 @@ class Extractor(ABC):
|
||||
class VariableFeaturesExtractor(Extractor):
|
||||
def extract(self, instances):
|
||||
result = {}
|
||||
for instance in tqdm(instances,
|
||||
for instance in tqdm(InstanceIterator(instances),
|
||||
desc="Extract (vars)",
|
||||
disable=len(instances) < 5):
|
||||
instance_features = instance.get_instance_features()
|
||||
@@ -59,7 +86,7 @@ class SolutionExtractor(Extractor):
|
||||
|
||||
def extract(self, instances):
|
||||
result = {}
|
||||
for instance in tqdm(instances,
|
||||
for instance in tqdm(InstanceIterator(instances),
|
||||
desc="Extract (solution)",
|
||||
disable=len(instances) < 5):
|
||||
var_split = self.split_variables(instance)
|
||||
@@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor):
|
||||
instance.get_instance_features(),
|
||||
instance.lp_value,
|
||||
])
|
||||
for instance in instances
|
||||
for instance in InstanceIterator(instances)
|
||||
])
|
||||
|
||||
|
||||
@@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor):
|
||||
|
||||
def extract(self, instances):
|
||||
if self.kind == "lower bound":
|
||||
return np.array([[instance.lower_bound] for instance in instances])
|
||||
return np.array([[instance.lower_bound]
|
||||
for instance in InstanceIterator(instances)])
|
||||
if self.kind == "upper bound":
|
||||
return np.array([[instance.upper_bound] for instance in instances])
|
||||
return np.array([[instance.upper_bound]
|
||||
for instance in InstanceIterator(instances)])
|
||||
if self.kind == "lp":
|
||||
return np.array([[instance.lp_value] for instance in instances])
|
||||
return np.array([[instance.lp_value]
|
||||
for instance in InstanceIterator(instances)])
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
import pickle
|
||||
import os
|
||||
import tempfile
|
||||
import gzip
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional, List
|
||||
@@ -198,9 +199,16 @@ class LearningSolver:
|
||||
"""
|
||||
|
||||
filename = None
|
||||
fileformat = None
|
||||
if isinstance(instance, str):
|
||||
filename = instance
|
||||
logger.info("Reading: %s" % filename)
|
||||
if filename.endswith(".gz"):
|
||||
fileformat = "pickle-gz"
|
||||
with gzip.GzipFile(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
else:
|
||||
fileformat = "pickle"
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
|
||||
@@ -260,9 +268,12 @@ class LearningSolver:
|
||||
if len(output) == 0:
|
||||
output_filename = filename
|
||||
logger.info("Writing: %s" % output_filename)
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
pickle.dump(instance, tmp)
|
||||
os.replace(tmp.name, output_filename)
|
||||
if fileformat == "pickle":
|
||||
with open(output_filename, "wb") as file:
|
||||
pickle.dump(instance, file)
|
||||
else:
|
||||
with gzip.GzipFile(output_filename, "wb") as file:
|
||||
pickle.dump(instance, file)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user