diff --git a/miplearn/__init__.py b/miplearn/__init__.py index 9c1c350..d7e6a9a 100644 --- a/miplearn/__init__.py +++ b/miplearn/__init__.py @@ -9,7 +9,10 @@ from .components.warmstart import (WarmStartComponent, AdaptivePredictor, ) from .components.branching import BranchPriorityComponent -from .extractors import UserFeaturesExtractor, SolutionExtractor +from .extractors import (UserFeaturesExtractor, + SolutionExtractor, + CombinedExtractor, + ) from .benchmark import BenchmarkRunner from .instance import Instance from .solvers import LearningSolver diff --git a/miplearn/components/warmstart.py b/miplearn/components/warmstart.py index 326f2ac..ed0a18c 100644 --- a/miplearn/components/warmstart.py +++ b/miplearn/components/warmstart.py @@ -128,15 +128,9 @@ class WarmStartComponent(Component): def before_solve(self, solver, instance, model): -# # Solve linear relaxation -# lr_solver = pe.SolverFactory("gurobi") -# lr_solver.options["threads"] = 4 -# lr_solver.options["relax_integrality"] = 1 -# lr_solver.solve(model, tee=solver.tee) - # Build x_test x_test = CombinedExtractor([UserFeaturesExtractor(), - SolutionExtractor(), + SolutionExtractor(relaxation=True), ]).extract([instance], [model]) # Update self.x_train diff --git a/miplearn/extractors.py b/miplearn/extractors.py index 1b6f529..2803c7f 100644 --- a/miplearn/extractors.py +++ b/miplearn/extractors.py @@ -69,8 +69,13 @@ class UserFeaturesExtractor(Extractor): class SolutionExtractor(Extractor): - def extract(self, instances, models): + def __init__(self, relaxation=False): + self.relaxation = relaxation + + def extract(self, instances, models=None): result = {} + if models is None: + models = [instance.to_model() for instance in instances] for (index, instance) in enumerate(instances): model = models[index] var_split = self.split_variables(instance, model) @@ -78,7 +83,10 @@ class SolutionExtractor(Extractor): if category not in result.keys(): result[category] = [] for (var, index) in var_index_pairs: - v = var[index].value + if self.relaxation: + v = instance.lp_solution[str(var)][index] + else: + v = instance.solution[str(var)][index] if v is None: result[category] += [[0, 0]] else: diff --git a/miplearn/solvers.py b/miplearn/solvers.py index 8481c29..1d5c0fb 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -168,8 +168,10 @@ class LearningSolver: solver.set_gap_tolerance(self.gap_tolerance) return solver - def solve(self, instance, tee=False): - model = instance.to_model() + def solve(self, instance, model=None, tee=False): + if model is None: + model = instance.to_model() + self.tee = tee self.internal_solver = self._create_internal_solver() diff --git a/miplearn/tests/test_extractors.py b/miplearn/tests/test_extractors.py index cd7f015..915a602 100644 --- a/miplearn/tests/test_extractors.py +++ b/miplearn/tests/test_extractors.py @@ -3,10 +3,11 @@ # Released under the modified BSD license. See COPYING.md for more details. from miplearn.problems.knapsack import KnapsackInstance -from miplearn.extractors import (UserFeaturesExtractor, - SolutionExtractor, - CombinedExtractor, - ) +from miplearn import (LearningSolver, + UserFeaturesExtractor, + SolutionExtractor, + CombinedExtractor, + ) import numpy as np import pyomo.environ as pe @@ -37,11 +38,11 @@ def test_user_features(): def test_solution_extractor(): instances = _get_instances() models = [instance.to_model() for instance in instances] - for model in models: - solver = pe.SolverFactory("cbc") - solver.solve(model) - extractor = SolutionExtractor() - features = extractor.extract(instances, models) + solver = LearningSolver() + for (i, instance) in enumerate(instances): + solver.solve(instances[i], models[i]) + + features = SolutionExtractor().extract(instances, models) assert isinstance(features, dict) assert "default" in features.keys() assert isinstance(features["default"], np.ndarray) @@ -59,6 +60,10 @@ def test_solution_extractor(): def test_combined_extractor(): instances = _get_instances() models = [instance.to_model() for instance in instances] + solver = LearningSolver() + for (i, instance) in enumerate(instances): + solver.solve(instances[i], models[i]) + extractor = CombinedExtractor(extractors=[UserFeaturesExtractor(), SolutionExtractor()]) features = extractor.extract(instances, models)