mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Train without loading all instances to memory
This commit is contained in:
@@ -37,7 +37,8 @@ class BenchmarkRunner:
|
|||||||
for (solver_name, solver) in self.solvers.items():
|
for (solver_name, solver) in self.solvers.items():
|
||||||
results = solver.parallel_solve(trials,
|
results = solver.parallel_solve(trials,
|
||||||
n_jobs=n_jobs,
|
n_jobs=n_jobs,
|
||||||
label="Solve (%s)" % solver_name)
|
label="Solve (%s)" % solver_name,
|
||||||
|
output=None)
|
||||||
for i in range(len(trials)):
|
for i in range(len(trials)):
|
||||||
idx = (i % len(instances)) + index_offset
|
idx = (i % len(instances)) + index_offset
|
||||||
self._push_result(results[i],
|
self._push_result(results[i],
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
|
|
||||||
for category in tqdm(features.keys(),
|
for category in tqdm(features.keys(),
|
||||||
desc="Fit (primal)",
|
desc="Fit (primal)",
|
||||||
disable=not sys.stdout.isatty(),
|
|
||||||
):
|
):
|
||||||
x_train = features[category]
|
x_train = features[category]
|
||||||
for label in [0, 1]:
|
for label in [0, 1]:
|
||||||
@@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
"Fix one": {}}
|
"Fix one": {}}
|
||||||
for instance_idx in tqdm(range(len(instances)),
|
for instance_idx in tqdm(range(len(instances)),
|
||||||
desc="Evaluate (primal)",
|
desc="Evaluate (primal)",
|
||||||
disable=not sys.stdout.isatty(),
|
|
||||||
):
|
):
|
||||||
instance = instances[instance_idx]
|
instance = instances[instance_idx]
|
||||||
solution_actual = instance.solution
|
solution_actual = instance.solution
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ from copy import deepcopy
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from miplearn.components import classifier_evaluation_dict
|
from miplearn.components import classifier_evaluation_dict
|
||||||
from tqdm import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from miplearn import Component
|
from miplearn import Component
|
||||||
from miplearn.classifiers.counting import CountingClassifier
|
from miplearn.classifiers.counting import CountingClassifier
|
||||||
|
from miplearn.extractors import InstanceIterator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -60,16 +61,12 @@ class RelaxationComponent(Component):
|
|||||||
instance.slacks = solver.internal_solver.get_constraint_slacks()
|
instance.slacks = solver.internal_solver.get_constraint_slacks()
|
||||||
|
|
||||||
def fit(self, training_instances):
|
def fit(self, training_instances):
|
||||||
training_instances = [instance
|
|
||||||
for instance in training_instances
|
|
||||||
if hasattr(instance, "slacks")]
|
|
||||||
logger.debug("Extracting x and y...")
|
logger.debug("Extracting x and y...")
|
||||||
x = self.x(training_instances)
|
x = self.x(training_instances)
|
||||||
y = self.y(training_instances)
|
y = self.y(training_instances)
|
||||||
logger.debug("Fitting...")
|
logger.debug("Fitting...")
|
||||||
for category in tqdm(x.keys(),
|
for category in tqdm(x.keys(),
|
||||||
desc="Fit (relaxation)",
|
desc="Fit (relaxation)"):
|
||||||
disable=not sys.stdout.isatty()):
|
|
||||||
if category not in self.classifiers:
|
if category not in self.classifiers:
|
||||||
self.classifiers[category] = deepcopy(self.classifier_prototype)
|
self.classifiers[category] = deepcopy(self.classifier_prototype)
|
||||||
self.classifiers[category].fit(x[category], y[category])
|
self.classifiers[category].fit(x[category], y[category])
|
||||||
@@ -80,7 +77,9 @@ class RelaxationComponent(Component):
|
|||||||
return_constraints=False):
|
return_constraints=False):
|
||||||
x = {}
|
x = {}
|
||||||
constraints = {}
|
constraints = {}
|
||||||
for instance in instances:
|
for instance in tqdm(InstanceIterator(instances),
|
||||||
|
desc="Extract (relaxation:x)",
|
||||||
|
disable=len(instances) < 5):
|
||||||
if constraint_ids is not None:
|
if constraint_ids is not None:
|
||||||
cids = constraint_ids
|
cids = constraint_ids
|
||||||
else:
|
else:
|
||||||
@@ -101,7 +100,9 @@ class RelaxationComponent(Component):
|
|||||||
|
|
||||||
def y(self, instances):
|
def y(self, instances):
|
||||||
y = {}
|
y = {}
|
||||||
for instance in instances:
|
for instance in tqdm(InstanceIterator(instances),
|
||||||
|
desc="Extract (relaxation:y)",
|
||||||
|
disable=len(instances) < 5):
|
||||||
for (cid, slack) in instance.slacks.items():
|
for (cid, slack) in instance.slacks.items():
|
||||||
category = instance.get_constraint_category(cid)
|
category = instance.get_constraint_category(cid)
|
||||||
if category is None:
|
if category is None:
|
||||||
@@ -120,7 +121,7 @@ class RelaxationComponent(Component):
|
|||||||
if category not in self.classifiers:
|
if category not in self.classifiers:
|
||||||
continue
|
continue
|
||||||
y[category] = []
|
y[category] = []
|
||||||
#x_cat = np.array(x_cat)
|
x_cat = np.array(x_cat)
|
||||||
proba = self.classifiers[category].predict_proba(x_cat)
|
proba = self.classifiers[category].predict_proba(x_cat)
|
||||||
for i in range(len(proba)):
|
for i in range(len(proba)):
|
||||||
if proba[i][1] >= self.threshold:
|
if proba[i][1] >= self.threshold:
|
||||||
|
|||||||
@@ -3,14 +3,41 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
import pickle
|
||||||
|
import gzip
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceIterator:
|
||||||
|
def __init__(self, instances):
|
||||||
|
self.instances = instances
|
||||||
|
self.current = 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.current >= len(self.instances):
|
||||||
|
raise StopIteration
|
||||||
|
result = self.instances[self.current]
|
||||||
|
self.current += 1
|
||||||
|
if isinstance(result, str):
|
||||||
|
logger.info("Read: %s" % result)
|
||||||
|
if result.endswith(".gz"):
|
||||||
|
with gzip.GzipFile(result, "rb") as file:
|
||||||
|
result = pickle.load(file)
|
||||||
|
else:
|
||||||
|
with open(result, "rb") as file:
|
||||||
|
result = pickle.load(file)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Extractor(ABC):
|
class Extractor(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def extract(self, instances,):
|
def extract(self, instances,):
|
||||||
@@ -34,7 +61,7 @@ class Extractor(ABC):
|
|||||||
class VariableFeaturesExtractor(Extractor):
|
class VariableFeaturesExtractor(Extractor):
|
||||||
def extract(self, instances):
|
def extract(self, instances):
|
||||||
result = {}
|
result = {}
|
||||||
for instance in tqdm(instances,
|
for instance in tqdm(InstanceIterator(instances),
|
||||||
desc="Extract (vars)",
|
desc="Extract (vars)",
|
||||||
disable=len(instances) < 5):
|
disable=len(instances) < 5):
|
||||||
instance_features = instance.get_instance_features()
|
instance_features = instance.get_instance_features()
|
||||||
@@ -59,7 +86,7 @@ class SolutionExtractor(Extractor):
|
|||||||
|
|
||||||
def extract(self, instances):
|
def extract(self, instances):
|
||||||
result = {}
|
result = {}
|
||||||
for instance in tqdm(instances,
|
for instance in tqdm(InstanceIterator(instances),
|
||||||
desc="Extract (solution)",
|
desc="Extract (solution)",
|
||||||
disable=len(instances) < 5):
|
disable=len(instances) < 5):
|
||||||
var_split = self.split_variables(instance)
|
var_split = self.split_variables(instance)
|
||||||
@@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor):
|
|||||||
instance.get_instance_features(),
|
instance.get_instance_features(),
|
||||||
instance.lp_value,
|
instance.lp_value,
|
||||||
])
|
])
|
||||||
for instance in instances
|
for instance in InstanceIterator(instances)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor):
|
|||||||
|
|
||||||
def extract(self, instances):
|
def extract(self, instances):
|
||||||
if self.kind == "lower bound":
|
if self.kind == "lower bound":
|
||||||
return np.array([[instance.lower_bound] for instance in instances])
|
return np.array([[instance.lower_bound]
|
||||||
|
for instance in InstanceIterator(instances)])
|
||||||
if self.kind == "upper bound":
|
if self.kind == "upper bound":
|
||||||
return np.array([[instance.upper_bound] for instance in instances])
|
return np.array([[instance.upper_bound]
|
||||||
|
for instance in InstanceIterator(instances)])
|
||||||
if self.kind == "lp":
|
if self.kind == "lp":
|
||||||
return np.array([[instance.lp_value] for instance in instances])
|
return np.array([[instance.lp_value]
|
||||||
|
for instance in InstanceIterator(instances)])
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import logging
|
|||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import gzip
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
@@ -198,11 +199,18 @@ class LearningSolver:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
filename = None
|
filename = None
|
||||||
|
fileformat = None
|
||||||
if isinstance(instance, str):
|
if isinstance(instance, str):
|
||||||
filename = instance
|
filename = instance
|
||||||
logger.info("Reading: %s" % filename)
|
logger.info("Reading: %s" % filename)
|
||||||
with open(filename, "rb") as file:
|
if filename.endswith(".gz"):
|
||||||
instance = pickle.load(file)
|
fileformat = "pickle-gz"
|
||||||
|
with gzip.GzipFile(filename, "rb") as file:
|
||||||
|
instance = pickle.load(file)
|
||||||
|
else:
|
||||||
|
fileformat = "pickle"
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
instance = pickle.load(file)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = instance.to_model()
|
model = instance.to_model()
|
||||||
@@ -260,9 +268,12 @@ class LearningSolver:
|
|||||||
if len(output) == 0:
|
if len(output) == 0:
|
||||||
output_filename = filename
|
output_filename = filename
|
||||||
logger.info("Writing: %s" % output_filename)
|
logger.info("Writing: %s" % output_filename)
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
if fileformat == "pickle":
|
||||||
pickle.dump(instance, tmp)
|
with open(output_filename, "wb") as file:
|
||||||
os.replace(tmp.name, output_filename)
|
pickle.dump(instance, file)
|
||||||
|
else:
|
||||||
|
with gzip.GzipFile(output_filename, "wb") as file:
|
||||||
|
pickle.dump(instance, file)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user