mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make WarmStartComponent use Extractor
This commit is contained in:
@@ -25,6 +25,23 @@ class Extractor(ABC):
|
|||||||
result[category] += [(var, index)]
|
result[category] += [(var, index)]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge(partial_results, vertical=False):
|
||||||
|
results = {}
|
||||||
|
all_categories = set()
|
||||||
|
for pr in partial_results:
|
||||||
|
all_categories |= pr.keys()
|
||||||
|
for category in all_categories:
|
||||||
|
results[category] = []
|
||||||
|
for pr in partial_results:
|
||||||
|
if category in pr.keys():
|
||||||
|
results[category] += [pr[category]]
|
||||||
|
if vertical:
|
||||||
|
results[category] = np.vstack(results[category])
|
||||||
|
else:
|
||||||
|
results[category] = np.hstack(results[category])
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
class UserFeaturesExtractor(Extractor):
|
class UserFeaturesExtractor(Extractor):
|
||||||
def extract(self,
|
def extract(self,
|
||||||
@@ -61,10 +78,20 @@ class SolutionExtractor(Extractor):
|
|||||||
if category not in result.keys():
|
if category not in result.keys():
|
||||||
result[category] = []
|
result[category] = []
|
||||||
for (var, index) in var_index_pairs:
|
for (var, index) in var_index_pairs:
|
||||||
result[category] += [[
|
v = var[index].value
|
||||||
1 - var[index].value,
|
if v is None:
|
||||||
var[index].value,
|
result[category] += [[0, 0]]
|
||||||
]]
|
else:
|
||||||
|
result[category] += [[1 - v, v]]
|
||||||
for category in result.keys():
|
for category in result.keys():
|
||||||
result[category] = np.vstack(result[category])
|
result[category] = np.vstack(result[category])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedExtractor(Extractor):
|
||||||
|
def __init__(self, extractors):
|
||||||
|
self.extractors = extractors
|
||||||
|
|
||||||
|
def extract(self, instances, models):
|
||||||
|
return self.merge([ex.extract(instances, models)
|
||||||
|
for ex in self.extractors])
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
|
# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
|
||||||
# Written by Alinson S. Xavier <axavier@anl.gov>
|
# Written by Alinson S. Xavier <axavier@anl.gov>
|
||||||
|
|
||||||
from .transformers import PerVariableTransformer
|
|
||||||
from .warmstart import WarmStartComponent
|
from .warmstart import WarmStartComponent
|
||||||
from .branching import BranchPriorityComponent
|
from .branching import BranchPriorityComponent
|
||||||
import pyomo.environ as pe
|
import pyomo.environ as pe
|
||||||
@@ -66,6 +65,12 @@ class LearningSolver:
|
|||||||
def solve(self, instance, tee=False):
|
def solve(self, instance, tee=False):
|
||||||
model = instance.to_model()
|
model = instance.to_model()
|
||||||
|
|
||||||
|
# Solve linear relaxation (TODO: use solver provided by user)
|
||||||
|
lr_solver = pe.SolverFactory("gurobi")
|
||||||
|
lr_solver.options["threads"] = 4
|
||||||
|
lr_solver.options["relax_integrality"] = 1
|
||||||
|
lr_solver.solve(model)
|
||||||
|
|
||||||
self._create_solver()
|
self._create_solver()
|
||||||
if self.is_persistent:
|
if self.is_persistent:
|
||||||
self.internal_solver.set_instance(model)
|
self.internal_solver.set_instance(model)
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
# Written by Alinson S. Xavier <axavier@anl.gov>
|
# Written by Alinson S. Xavier <axavier@anl.gov>
|
||||||
|
|
||||||
from miplearn.problems.knapsack import KnapsackInstance
|
from miplearn.problems.knapsack import KnapsackInstance
|
||||||
from miplearn import (UserFeaturesExtractor,
|
from miplearn.extractors import (UserFeaturesExtractor,
|
||||||
SolutionExtractor)
|
SolutionExtractor,
|
||||||
|
CombinedExtractor,
|
||||||
|
)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyomo.environ as pe
|
import pyomo.environ as pe
|
||||||
|
|
||||||
@@ -52,3 +54,16 @@ def test_solution_extractor():
|
|||||||
0., 1.,
|
0., 1.,
|
||||||
1., 0.,
|
1., 0.,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_combined_extractor():
|
||||||
|
instances = _get_instances()
|
||||||
|
models = [instance.to_model() for instance in instances]
|
||||||
|
extractor = CombinedExtractor(extractors=[UserFeaturesExtractor(),
|
||||||
|
SolutionExtractor()])
|
||||||
|
features = extractor.extract(instances, models)
|
||||||
|
assert isinstance(features, dict)
|
||||||
|
assert "default" in features.keys()
|
||||||
|
assert isinstance(features["default"], np.ndarray)
|
||||||
|
assert features["default"].shape == (6, 6)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def test_warm_start_save_load():
|
|||||||
solver.parallel_solve(_get_instances(), n_jobs=2)
|
solver.parallel_solve(_get_instances(), n_jobs=2)
|
||||||
solver.fit()
|
solver.fit()
|
||||||
comp = solver.components["warm-start"]
|
comp = solver.components["warm-start"]
|
||||||
assert comp.x_train["default"].shape == (8, 4)
|
assert comp.x_train["default"].shape == (8, 6)
|
||||||
assert comp.y_train["default"].shape == (8, 2)
|
assert comp.y_train["default"].shape == (8, 2)
|
||||||
assert "default" in comp.predictors.keys()
|
assert "default" in comp.predictors.keys()
|
||||||
solver.save_state(state_file.name)
|
solver.save_state(state_file.name)
|
||||||
@@ -32,6 +32,6 @@ def test_warm_start_save_load():
|
|||||||
solver = LearningSolver(components={"warm-start": WarmStartComponent()})
|
solver = LearningSolver(components={"warm-start": WarmStartComponent()})
|
||||||
solver.load_state(state_file.name)
|
solver.load_state(state_file.name)
|
||||||
comp = solver.components["warm-start"]
|
comp = solver.components["warm-start"]
|
||||||
assert comp.x_train["default"].shape == (8, 4)
|
assert comp.x_train["default"].shape == (8, 6)
|
||||||
assert comp.y_train["default"].shape == (8, 2)
|
assert comp.y_train["default"].shape == (8, 2)
|
||||||
assert "default" in comp.predictors.keys()
|
assert "default" in comp.predictors.keys()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
from . import Component
|
from . import Component
|
||||||
from .transformers import PerVariableTransformer
|
from .transformers import PerVariableTransformer
|
||||||
|
from .extractors import *
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -131,52 +132,43 @@ class WarmStartComponent(Component):
|
|||||||
self.predictor_prototype = predictor_prototype
|
self.predictor_prototype = predictor_prototype
|
||||||
|
|
||||||
def before_solve(self, solver, instance, model):
|
def before_solve(self, solver, instance, model):
|
||||||
var_split = self.transformer.split_variables(instance, model)
|
# Build x_test
|
||||||
x_test = {}
|
x_test = CombinedExtractor([UserFeaturesExtractor(),
|
||||||
|
SolutionExtractor(),
|
||||||
|
]).extract([instance], [model])
|
||||||
|
|
||||||
# Collect training data (x_train) and build x_test
|
# Update self.x_train
|
||||||
for category in var_split.keys():
|
self.x_train = Extractor.merge([self.x_train, x_test],
|
||||||
var_index_pairs = var_split[category]
|
vertical=True)
|
||||||
x = self.transformer.transform_instance(instance, var_index_pairs)
|
|
||||||
x_test[category] = x
|
|
||||||
if category not in self.x_train.keys():
|
|
||||||
self.x_train[category] = x
|
|
||||||
else:
|
|
||||||
assert x.shape[1] == self.x_train[category].shape[1]
|
|
||||||
self.x_train[category] = np.vstack([self.x_train[category], x])
|
|
||||||
|
|
||||||
# Predict solutions
|
# Predict solutions
|
||||||
|
var_split = Extractor.split_variables(instance, model)
|
||||||
for category in var_split.keys():
|
for category in var_split.keys():
|
||||||
var_index_pairs = var_split[category]
|
var_index_pairs = var_split[category]
|
||||||
if category in self.predictors.keys():
|
if category not in self.predictors.keys():
|
||||||
ws = self.predictors[category].predict(x_test[category])
|
continue
|
||||||
assert ws.shape == (len(var_index_pairs), 2)
|
ws = self.predictors[category].predict(x_test[category])
|
||||||
for i in range(len(var_index_pairs)):
|
assert ws.shape == (len(var_index_pairs), 2)
|
||||||
var, index = var_index_pairs[i]
|
for i in range(len(var_index_pairs)):
|
||||||
if self.mode == "heuristic":
|
var, index = var_index_pairs[i]
|
||||||
if ws[i,0] == 1:
|
if self.mode == "heuristic":
|
||||||
var[index].fix(0)
|
if ws[i,0] == 1:
|
||||||
if solver.is_persistent:
|
var[index].fix(0)
|
||||||
solver.internal_solver.update_var(var[index])
|
if solver.is_persistent:
|
||||||
elif ws[i,1] == 1:
|
solver.internal_solver.update_var(var[index])
|
||||||
var[index].fix(1)
|
elif ws[i,1] == 1:
|
||||||
if solver.is_persistent:
|
var[index].fix(1)
|
||||||
solver.internal_solver.update_var(var[index])
|
if solver.is_persistent:
|
||||||
else:
|
solver.internal_solver.update_var(var[index])
|
||||||
if ws[i,0] == 1:
|
else:
|
||||||
var[index].value = 0
|
if ws[i,0] == 1:
|
||||||
elif ws[i,1] == 1:
|
var[index].value = 0
|
||||||
var[index].value = 1
|
elif ws[i,1] == 1:
|
||||||
|
var[index].value = 1
|
||||||
|
|
||||||
def after_solve(self, solver, instance, model):
|
def after_solve(self, solver, instance, model):
|
||||||
var_split = self.transformer.split_variables(instance, model)
|
y_test = SolutionExtractor().extract([instance], [model])
|
||||||
for category in var_split.keys():
|
self.y_train = Extractor.merge([self.y_train, y_test], vertical=True)
|
||||||
var_index_pairs = var_split[category]
|
|
||||||
y = self.transformer.transform_solution(var_index_pairs)
|
|
||||||
if category not in self.y_train.keys():
|
|
||||||
self.y_train[category] = y
|
|
||||||
else:
|
|
||||||
self.y_train[category] = np.vstack([self.y_train[category], y])
|
|
||||||
|
|
||||||
def fit(self, solver, n_jobs=1):
|
def fit(self, solver, n_jobs=1):
|
||||||
for category in tqdm(self.x_train.keys(), desc="Warm start"):
|
for category in tqdm(self.x_train.keys(), desc="Warm start"):
|
||||||
|
|||||||
Reference in New Issue
Block a user