diff --git a/miplearn/benchmark.py b/miplearn/benchmark.py index cf23961..1108252 100644 --- a/miplearn/benchmark.py +++ b/miplearn/benchmark.py @@ -37,7 +37,8 @@ class BenchmarkRunner: for (solver_name, solver) in self.solvers.items(): results = solver.parallel_solve(trials, n_jobs=n_jobs, - label="Solve (%s)" % solver_name) + label="Solve (%s)" % solver_name, + output=None) for i in range(len(trials)): idx = (i % len(instances)) + index_offset self._push_result(results[i], diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 6c2cc84..8c1c88e 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component): for category in tqdm(features.keys(), desc="Fit (primal)", - disable=not sys.stdout.isatty(), ): x_train = features[category] for label in [0, 1]: @@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component): "Fix one": {}} for instance_idx in tqdm(range(len(instances)), desc="Evaluate (primal)", - disable=not sys.stdout.isatty(), ): instance = instances[instance_idx] solution_actual = instance.solution diff --git a/miplearn/components/relaxation.py b/miplearn/components/relaxation.py index 9320ad2..bbcf0d3 100644 --- a/miplearn/components/relaxation.py +++ b/miplearn/components/relaxation.py @@ -8,10 +8,11 @@ from copy import deepcopy import numpy as np from miplearn.components import classifier_evaluation_dict -from tqdm import tqdm +from tqdm.auto import tqdm from miplearn import Component from miplearn.classifiers.counting import CountingClassifier +from miplearn.extractors import InstanceIterator logger = logging.getLogger(__name__) @@ -60,16 +61,12 @@ class RelaxationComponent(Component): instance.slacks = solver.internal_solver.get_constraint_slacks() def fit(self, training_instances): - training_instances = [instance - for instance in training_instances - if hasattr(instance, "slacks")] logger.debug("Extracting x and y...") x = self.x(training_instances) y = self.y(training_instances) logger.debug("Fitting...") for category in tqdm(x.keys(), - desc="Fit (relaxation)", - disable=not sys.stdout.isatty()): + desc="Fit (relaxation)"): if category not in self.classifiers: self.classifiers[category] = deepcopy(self.classifier_prototype) self.classifiers[category].fit(x[category], y[category]) @@ -80,7 +77,9 @@ class RelaxationComponent(Component): return_constraints=False): x = {} constraints = {} - for instance in instances: + for instance in tqdm(InstanceIterator(instances), + desc="Extract (relaxation:x)", + disable=len(instances) < 5): if constraint_ids is not None: cids = constraint_ids else: @@ -101,7 +100,9 @@ class RelaxationComponent(Component): def y(self, instances): y = {} - for instance in instances: + for instance in tqdm(InstanceIterator(instances), + desc="Extract (relaxation:y)", + disable=len(instances) < 5): for (cid, slack) in instance.slacks.items(): category = instance.get_constraint_category(cid) if category is None: @@ -120,7 +121,7 @@ class RelaxationComponent(Component): if category not in self.classifiers: continue y[category] = [] - #x_cat = np.array(x_cat) + x_cat = np.array(x_cat) proba = self.classifiers[category].predict_proba(x_cat) for i in range(len(proba)): if proba[i][1] >= self.threshold: diff --git a/miplearn/extractors.py b/miplearn/extractors.py index 1766dc5..24d81d4 100644 --- a/miplearn/extractors.py +++ b/miplearn/extractors.py @@ -3,14 +3,41 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging -from abc import ABC, abstractmethod +import pickle +import gzip import numpy as np -from tqdm import tqdm + +from tqdm.auto import tqdm +from abc import ABC, abstractmethod logger = logging.getLogger(__name__) +class InstanceIterator: + def __init__(self, instances): + self.instances = instances + self.current = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current >= len(self.instances): + raise StopIteration + result = self.instances[self.current] + self.current += 1 + if isinstance(result, str): + logger.info("Read: %s" % result) + if result.endswith(".gz"): + with gzip.GzipFile(result, "rb") as file: + result = pickle.load(file) + else: + with open(result, "rb") as file: + result = pickle.load(file) + return result + + class Extractor(ABC): @abstractmethod def extract(self, instances,): @@ -34,7 +61,7 @@ class Extractor(ABC): class VariableFeaturesExtractor(Extractor): def extract(self, instances): result = {} - for instance in tqdm(instances, + for instance in tqdm(InstanceIterator(instances), desc="Extract (vars)", disable=len(instances) < 5): instance_features = instance.get_instance_features() @@ -59,7 +86,7 @@ class SolutionExtractor(Extractor): def extract(self, instances): result = {} - for instance in tqdm(instances, + for instance in tqdm(InstanceIterator(instances), desc="Extract (solution)", disable=len(instances) < 5): var_split = self.split_variables(instance) @@ -87,7 +114,7 @@ class InstanceFeaturesExtractor(Extractor): instance.get_instance_features(), instance.lp_value, ]) - for instance in instances + for instance in InstanceIterator(instances) ]) @@ -98,8 +125,11 @@ class ObjectiveValueExtractor(Extractor): def extract(self, instances): if self.kind == "lower bound": - return np.array([[instance.lower_bound] for instance in instances]) + return np.array([[instance.lower_bound] + for instance in InstanceIterator(instances)]) if self.kind == "upper bound": - return np.array([[instance.upper_bound] for instance in instances]) + return np.array([[instance.upper_bound] + for instance in InstanceIterator(instances)]) if self.kind == "lp": - return np.array([[instance.lp_value] for instance in instances]) + return np.array([[instance.lp_value] + for instance in InstanceIterator(instances)]) diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 0cf2616..017abc1 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -6,6 +6,7 @@ import logging import pickle import os import tempfile +import gzip from copy import deepcopy from typing import Optional, List @@ -198,11 +199,18 @@ class LearningSolver: """ filename = None + fileformat = None if isinstance(instance, str): filename = instance logger.info("Reading: %s" % filename) - with open(filename, "rb") as file: - instance = pickle.load(file) + if filename.endswith(".gz"): + fileformat = "pickle-gz" + with gzip.GzipFile(filename, "rb") as file: + instance = pickle.load(file) + else: + fileformat = "pickle" + with open(filename, "rb") as file: + instance = pickle.load(file) if model is None: model = instance.to_model() @@ -260,9 +268,12 @@ class LearningSolver: if len(output) == 0: output_filename = filename logger.info("Writing: %s" % output_filename) - with tempfile.NamedTemporaryFile(delete=False) as tmp: - pickle.dump(instance, tmp) - os.replace(tmp.name, output_filename) + if fileformat == "pickle": + with open(output_filename, "wb") as file: + pickle.dump(instance, file) + else: + with gzip.GzipFile(output_filename, "wb") as file: + pickle.dump(instance, file) return results