Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 01:18:52 -06:00)
Modularize LearningSolver into components; implement branch-priority
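In rough terms, this commit moves the warm-start and branch-priority machinery out of LearningSolver and into pluggable Component objects, keyed by name. A minimal sketch of the resulting workflow, pieced together from the tests in this diff (illustrative only; the instance data is borrowed from test_solve_save_load_state below):

    from miplearn import LearningSolver
    from miplearn.warmstart import WarmStartComponent
    from miplearn.branching import BranchPriorityComponent
    from miplearn.problems.knapsack import KnapsackInstance2

    instance = KnapsackInstance2(weights=[23., 26., 20., 18.],
                                 prices=[505., 352., 458., 220.],
                                 capacity=67.)

    # Components hook into before_solve/after_solve and are trained by fit().
    solver = LearningSolver(components={
        "warm-start": WarmStartComponent(),
        "branch-priority": BranchPriorityComponent(),
    })
    solver.solve(instance)             # components collect training data
    solver.fit()                       # each component trains itself
    solver.save_state("training.bin")  # pickles the components (version 2)

    # A fresh solver can later absorb the trained components:
    solver2 = LearningSolver()
    solver2.load_state("training.bin")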
Makefile (2 changes)
@@ -1,4 +1,4 @@
-PYTEST_ARGS := -W ignore::DeprecationWarning --capture=no -vv
+PYTEST_ARGS := -W ignore::DeprecationWarning -vv
 
 test:
 	pytest $(PYTEST_ARGS)
miplearn/__init__.py (2 changes)
@@ -2,6 +2,8 @@
 # Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
 # Written by Alinson S. Xavier <axavier@anl.gov>
 
+
+from .component import Component
 from .instance import Instance
 from .solvers import LearningSolver
 from .benchmark import BenchmarkRunner
miplearn/benchmark.py
@@ -13,11 +13,6 @@ class BenchmarkRunner:
         self.solvers = solvers
         self.results = None
 
-    def load_fit(self, filename):
-        for (name, solver) in self.solvers.items():
-            solver.load(filename)
-            solver.fit()
-
     def parallel_solve(self, instances, n_jobs=1, n_trials=1):
         if self.results is None:
             self.results = pd.DataFrame(columns=["Solver",
@@ -79,3 +74,11 @@ class BenchmarkRunner:
 
     def load_results(self, filename):
         self.results = pd.read_csv(filename, index_col=0)
+
+    def load_state(self, filename):
+        for (name, solver) in self.solvers.items():
+            solver.load_state(filename)
+
+    def fit(self):
+        for (name, solver) in self.solvers.items():
+            solver.fit()
miplearn/branching.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+# MIPLearn, an extensible framework for Learning-Enhanced Mixed-Integer Optimization
+# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
+# Written by Alinson S. Xavier <axavier@anl.gov>
+
+from . import Component
+from .transformers import PerVariableTransformer
+from abc import ABC, abstractmethod
+import numpy as np
+
+
+class BranchPriorityComponent(Component):
+    def __init__(self,
+                 initial_priority=None,
+                 collect_training_data=True):
+        self.priority = initial_priority
+        self.transformer = PerVariableTransformer()
+        self.collect_training_data = collect_training_data
+
+    def before_solve(self, solver, instance, model):
+        assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver"
+        var_split = self.transformer.split_variables(instance, model)
+        for category in var_split.keys():
+            var_index_pairs = var_split[category]
+            if self.priority is not None:
+                from gurobipy import GRB
+                for (i, (var, index)) in enumerate(var_index_pairs):
+                    gvar = solver.internal_solver._pyomo_var_to_solver_var_map[var[index]]
+                    gvar.setAttr(GRB.Attr.BranchPriority, int(self.priority[index]))
+
+
+    def after_solve(self, solver, instance, model):
+        if self.collect_training_data:
+            import subprocess, tempfile, os
+            src_dirname = os.path.dirname(os.path.realpath(__file__))
+            model_file = tempfile.NamedTemporaryFile(suffix=".lp")
+            priority_file = tempfile.NamedTemporaryFile()
+            solver.internal_solver.write(model_file.name)
+            subprocess.run(["julia",
+                            "%s/scripts/branchpriority.jl" % src_dirname,
+                            model_file.name,
+                            priority_file.name],
+                           check=True)
+            self._merge(np.genfromtxt(priority_file.name,
+                                      delimiter=',',
+                                      dtype=int))
+
+
+    def fit(self, solver):
+        pass
+
+
+    def merge(self, other):
+        if other.priority is not None:
+            self._merge(other.priority)
+
+
+    def _merge(self, priority):
+        assert isinstance(priority, np.ndarray)
+        if self.priority is None:
+            self.priority = priority
+        else:
+            assert self.priority.shape == priority.shape
+            self.priority += priority
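A note on the interop in after_solve above: the component and the Julia script communicate only through files. The model is written out in .lp format, and scripts/branchpriority.jl is expected to reply with a single CSV line containing one integer priority per variable, which np.genfromtxt parses back. A tiny sketch of that contract, with made-up values:

    import io
    import numpy as np

    # Hypothetical one-line output of scripts/branchpriority.jl:
    # one integer priority per variable, in column order.
    csv_line = "0,150,75,0,300"
    priorities = np.genfromtxt(io.StringIO(csv_line), delimiter=',', dtype=int)
    assert priorities.shape == (5,)  # one entry per variable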
miplearn/component.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+# MIPLearn, an extensible framework for Learning-Enhanced Mixed-Integer Optimization
+# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
+# Written by Alinson S. Xavier <axavier@anl.gov>
+
+from abc import ABC, abstractmethod
+
+
+class Component(ABC):
+    @abstractmethod
+    def fit(self, solver):
+        pass
+
+    @abstractmethod
+    def before_solve(self, solver, instance, model):
+        pass
+
+    @abstractmethod
+    def after_solve(self, solver, instance, model):
+        pass
+
+    @abstractmethod
+    def merge(self, other):
+        pass
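Any future component only has to implement these four hooks. A hypothetical example, not part of this commit, showing the minimal shape of a subclass:

    from miplearn import Component

    class CounterComponent(Component):
        """Toy component that counts how often each hook fires."""

        def __init__(self):
            self.calls = {"fit": 0, "before": 0, "after": 0}

        def fit(self, solver):
            self.calls["fit"] += 1

        def before_solve(self, solver, instance, model):
            self.calls["before"] += 1

        def after_solve(self, solver, instance, model):
            self.calls["after"] += 1

        def merge(self, other):
            # Fold in statistics gathered by copies of this component
            # that ran in parallel worker processes (see parallel_solve).
            for key in self.calls:
                self.calls[key] += other.calls[key]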
@@ -12,11 +12,11 @@ from scipy.stats import uniform, randint
 
 def test_stab():
     graph = nx.cycle_graph(5)
-    weights = [1., 2., 3., 4., 5.]
+    weights = [1., 1., 1., 1., 1.]
     instance = MaxWeightStableSetInstance(graph, weights)
     solver = LearningSolver()
     solver.solve(instance)
-    assert instance.model.OBJ() == 8.
+    assert instance.model.OBJ() == 2.
 
 
 def test_stab_generator_fixed_graph():
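Sanity check on the two objective values in this hunk: a stable set in the 5-cycle contains at most two (non-adjacent) vertices, so with weights [1., 2., 3., 4., 5.] the optimum picks the vertices weighing 3 and 5 (objective 8), while with unit weights any such pair yields 2. A brute-force verification, illustrative only:

    import itertools
    import networkx as nx

    graph = nx.cycle_graph(5)
    weights = [1., 2., 3., 4., 5.]

    # Enumerate every independent set of the 5-cycle.
    best = max(
        sum(weights[v] for v in subset)
        for r in range(len(graph) + 1)
        for subset in itertools.combinations(graph.nodes, r)
        if all(not graph.has_edge(u, v)
               for u, v in itertools.combinations(subset, 2))
    )
    assert best == 8.0  # with unit weights the same search yields 2.0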
miplearn/scripts/branchpriority.jl (new file, 67 lines)
@@ -0,0 +1,67 @@
+import Base.Threads.@threads
+using TinyBnB, CPLEXW, Printf
+
+instance_name = ARGS[1]
+output_filename = ARGS[2]
+
+mip = open_mip(instance_name)
+n_vars = CPXgetnumcols(mip.cplex_env[1], mip.cplex_lp[1])
+
+pseudocost_count_up = [0 for i in 1:n_vars]
+pseudocost_count_down = [0 for i in 1:n_vars]
+pseudocost_sum_up = [0. for i in 1:n_vars]
+pseudocost_sum_down = [0. for i in 1:n_vars]
+
+function full_strong_branching_track(node::Node, progress::Progress)::TinyBnB.Variable
+    N = length(node.fractional_variables)
+    scores = Array{Float64}(undef, N)
+    rates_up = Array{Float64}(undef, N)
+    rates_down = Array{Float64}(undef, N)
+
+    @threads for v in 1:N
+        fix_vars!(node.mip, node.branch_variables, node.branch_values)
+        obj_up, obj_down = TinyBnB.probe(node.mip, node.fractional_variables[v])
+        unfix_vars!(node.mip, node.branch_variables)
+        delta_up = obj_up - node.obj
+        delta_down = obj_down - node.obj
+        frac_up = ceil(node.fractional_values[v]) - node.fractional_values[v]
+        frac_down = node.fractional_values[v] - floor(node.fractional_values[v])
+        rates_up[v] = delta_up / frac_up
+        rates_down[v] = delta_down / frac_down
+        scores[v] = delta_up * delta_down
+    end
+
+    max_score, max_offset = findmax(scores)
+    selected_var = node.fractional_variables[max_offset]
+
+    if rates_up[max_offset] < 1e6
+        pseudocost_count_up[selected_var.index] += 1
+        pseudocost_sum_up[selected_var.index] += rates_up[max_offset]
+    end
+
+    if rates_down[max_offset] < 1e6
+        pseudocost_count_down[selected_var.index] += 1
+        pseudocost_sum_down[selected_var.index] += rates_down[max_offset]
+    end
+
+    return selected_var
+end
+
+branch_and_bound(mip,
+                 node_limit = 1000,
+                 branch_rule = full_strong_branching_track,
+                 node_rule = best_bound,
+                 print_interval = 1)
+
+priority = [(pseudocost_count_up[v] == 0 || pseudocost_count_down[v] == 0) ? 0 :
+            (pseudocost_sum_up[v] / pseudocost_count_up[v]) *
+            (pseudocost_sum_down[v] / pseudocost_count_down[v])
+            for v in 1:n_vars];
+
+open(output_filename, "w") do file
+    for v in 1:n_vars
+        v == 1 || write(file, ",")
+        write(file, @sprintf("%.0f", priority[v]))
+    end
+    write(file, "\n")
+end
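For reference, the comprehension at the end of the script assigns each variable a priority equal to the product of its average up- and down-pseudocost rates, or zero when either direction was never observed. The same computation in Python, as a sketch:

    def compute_priorities(count_up, count_down, sum_up, sum_down):
        # priority[v] = avg_up_rate * avg_down_rate, or 0 if no observations
        return [
            0 if (cu == 0 or cd == 0) else (su / cu) * (sd / cd)
            for cu, cd, su, sd in zip(count_up, count_down, sum_up, sum_down)
        ]

    # (10/2) * (4/1) = 20; the second variable was never branched up, so 0.
    assert compute_priorities([2, 0], [1, 3], [10., 0.], [4., 9.]) == [20.0, 0]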
miplearn/solvers.py
@@ -3,10 +3,11 @@
 # Written by Alinson S. Xavier <axavier@anl.gov>
 
 from .transformers import PerVariableTransformer
-from .warmstart import KnnWarmStartPredictor, LogisticWarmStartPredictor
+from .warmstart import WarmStartComponent
+from .branching import BranchPriorityComponent
 import pyomo.environ as pe
 import numpy as np
-from copy import copy, deepcopy
+from copy import deepcopy
 import pickle
 from tqdm import tqdm
 from joblib import Parallel, delayed
@@ -20,6 +21,7 @@ def _gurobi_factory():
     solver.options["Seed"] = randint(low=0, high=1000).rvs()
     return solver
 
+
 class LearningSolver:
     """
     Mixed-Integer Linear Programming (MIP) solver that extracts information from previous runs,
@@ -29,144 +31,95 @@ class LearningSolver:
     def __init__(self,
                  threads=4,
                  internal_solver_factory=_gurobi_factory,
-                 ws_predictor=LogisticWarmStartPredictor(),
-                 branch_priority=None,
-                 mode="exact"):
+                 components=None,
+                 mode=None):
+        self.is_persistent = None
+        self.internal_solver = None
+        self.components = components
         self.internal_solver_factory = internal_solver_factory
-        self.mode = mode
-        self.x_train = {}
-        self.y_train = {}
-        self.ws_predictors = {}
-        self.ws_predictor_prototype = ws_predictor
-        self.branch_priority = branch_priority
+
+        if self.components is not None:
+            assert isinstance(self.components, dict)
+        else:
+            self.components = {
+                "warm-start": WarmStartComponent(),
+                "branch-priority": BranchPriorityComponent(),
+            }
+
+        if mode is not None:
+            assert mode in ["exact", "heuristic"]
+            for component in self.components:
+                component.mode = mode
+
+    def _create_solver(self):
+        self.internal_solver = self.internal_solver_factory()
+        self.is_persistent = hasattr(self.internal_solver, "set_instance")
+
+    def _clear(self):
+        self.internal_solver = None
 
     def solve(self, instance, tee=False):
-        # Load model into solver
         model = instance.to_model()
-        is_solver_persistent = hasattr(self.internal_solver, "set_instance")
-        if is_solver_persistent:
+        self._create_solver()
+        if self.is_persistent:
             self.internal_solver.set_instance(model)
 
-        # Split decision variables according to their category
-        transformer = PerVariableTransformer()
-        var_split = transformer.split_variables(instance, model)
+        for component in self.components.values():
+            component.before_solve(self, instance, model)
 
-        # Build x_test and update x_train
-        x_test = {}
-        for category in var_split.keys():
-            var_index_pairs = var_split[category]
-            x = transformer.transform_instance(instance, var_index_pairs)
-            x_test[category] = x
-            if category not in self.x_train.keys():
-                self.x_train[category] = x
-            else:
-                self.x_train[category] = np.vstack([self.x_train[category], x])
-
-        for category in var_split.keys():
-            var_index_pairs = var_split[category]
-
-            # Predict warm starts
-            if category in self.ws_predictors.keys():
-                ws = self.ws_predictors[category].predict(x_test[category])
-                assert ws.shape == (len(var_index_pairs), 2)
-                for i in range(len(var_index_pairs)):
-                    var, index = var_index_pairs[i]
-                    if self.mode == "heuristic":
-                        if ws[i,0] == 1:
-                            var[index].fix(0)
-                        elif ws[i,1] == 1:
-                            var[index].fix(1)
-                    else:
-                        if ws[i,0] == 1:
-                            var[index].value = 0
-                        elif ws[i,1] == 1:
-                            var[index].value = 1
-
-            # Set custom branch priority
-            if self.branch_priority is not None:
-                assert is_solver_persistent
-                from gurobipy import GRB
-                for (i, (var, index)) in enumerate(var_index_pairs):
-                    gvar = self.internal_solver._pyomo_var_to_solver_var_map[var[index]]
-                    #priority = randint(low=0, high=1000).rvs()
-                    gvar.setAttr(GRB.Attr.BranchPriority, self.branch_priority[index])
-
-        if is_solver_persistent:
+        if self.is_persistent:
             solve_results = self.internal_solver.solve(tee=tee, warmstart=True)
         else:
             solve_results = self.internal_solver.solve(model, tee=tee, warmstart=True)
 
         solve_results["Solver"][0]["Nodes"] = self.internal_solver._solver_model.getAttr("NodeCount")
 
-        # Update y_train
-        for category in var_split.keys():
-            var_index_pairs = var_split[category]
-            y = transformer.transform_solution(var_index_pairs)
-            if category not in self.y_train.keys():
-                self.y_train[category] = y
-            else:
-                self.y_train[category] = np.vstack([self.y_train[category], y])
+        for component in self.components.values():
+            component.after_solve(self, instance, model)
 
         return solve_results
 
     def parallel_solve(self, instances, n_jobs=4, label="Solve"):
-        self.parentSolver = None
+        self._clear()
 
        def _process(instance):
-            solver = copy(self)
-            solver.internal_solver = solver.internal_solver_factory()
+            solver = deepcopy(self)
             results = solver.solve(instance)
-            return {
-                "x_train": solver.x_train,
-                "y_train": solver.y_train,
-                "results": results,
-            }
+            solver._clear()
+            return solver, results
 
-        def _merge(results):
-            categories = results[0]["x_train"].keys()
-            x_entries = [np.vstack([r["x_train"][c] for r in results]) for c in categories]
-            y_entries = [np.vstack([r["y_train"][c] for r in results]) for c in categories]
-            x_train = dict(zip(categories, x_entries))
-            y_train = dict(zip(categories, y_entries))
-            results = [r["results"] for r in results]
-            return x_train, y_train, results
-
-        results = Parallel(n_jobs=n_jobs)(
+        solver_result_pairs = Parallel(n_jobs=n_jobs)(
             delayed(_process)(instance)
             for instance in tqdm(instances, desc=label, ncols=80)
         )
 
-        x_train, y_train, results = _merge(results)
-        self.x_train = x_train
-        self.y_train = y_train
+        solvers = [p[0] for p in solver_result_pairs]
+        results = [p[1] for p in solver_result_pairs]
+
+        for (name, component) in self.components.items():
+            for subsolver in solvers:
+                self.components[name].merge(subsolver.components[name])
 
         return results
 
-    def fit(self, x_train_dict=None, y_train_dict=None):
-        if x_train_dict is None:
-            x_train_dict = self.x_train
-            y_train_dict = self.y_train
-        for category in x_train_dict.keys():
-            x_train = x_train_dict[category]
-            y_train = y_train_dict[category]
-            if self.ws_predictor_prototype is not None:
-                self.ws_predictors[category] = deepcopy(self.ws_predictor_prototype)
-                self.ws_predictors[category].fit(x_train, y_train)
+    def fit(self):
+        for component in self.components.values():
+            component.fit(self)
 
-    def save(self, filename):
+    def save_state(self, filename):
         with open(filename, "wb") as file:
             pickle.dump({
-                "version": 1,
-                "x_train": self.x_train,
-                "y_train": self.y_train,
-                "ws_predictors": self.ws_predictors,
+                "version": 2,
+                "components": self.components,
             }, file)
 
-    def load(self, filename):
+    def load_state(self, filename):
         with open(filename, "rb") as file:
            data = pickle.load(file)
-            assert data["version"] == 1
-            self.x_train = data["x_train"]
-            self.y_train = data["y_train"]
-            self.ws_predictors = self.ws_predictors
+            assert data["version"] == 2
+            for (component_name, component) in data["components"].items():
+                if component_name not in self.components.keys():
+                    continue
+                else:
+                    self.components[component_name].merge(component)
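The new parallel_solve is worth a second look: each worker receives a deepcopy of the entire solver, solves one instance, and the training data accumulated by the copy's components is folded back into the master through Component.merge. A stripped-down sketch of the pattern (the real code dispatches through joblib):

    from copy import deepcopy

    def parallel_solve_sketch(master, instances, solve_one):
        workers = []
        for instance in instances:
            worker = deepcopy(master)    # isolated copy per instance
            solve_one(worker, instance)  # worker's components collect data
            workers.append(worker)
        for name, component in master.components.items():
            for worker in workers:       # fold worker state back in
                component.merge(worker.components[name])
        return master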
@@ -19,15 +19,16 @@ def test_benchmark():
     # Training phase...
     training_solver = LearningSolver()
     training_solver.parallel_solve(train_instances, n_jobs=10)
-    training_solver.save("data.bin")
+    training_solver.fit()
+    training_solver.save_state("data.bin")
 
     # Test phase...
     test_solvers = {
-        "Strategy A": LearningSolver(ws_predictor=None),
-        "Strategy B": LearningSolver(ws_predictor=None),
+        "Strategy A": LearningSolver(),
+        "Strategy B": LearningSolver(),
     }
     benchmark = BenchmarkRunner(test_solvers)
-    benchmark.load_fit("data.bin")
+    benchmark.load_state("data.bin")
     benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
     assert benchmark.raw_results().values.shape == (12,12)
 
@@ -4,6 +4,8 @@
 
 from miplearn import LearningSolver
 from miplearn.problems.knapsack import KnapsackInstance2
+from miplearn.branching import BranchPriorityComponent
+from miplearn.warmstart import WarmStartComponent
 import numpy as np
 
 
@@ -16,21 +18,29 @@ def test_solver():
     solver.fit()
     solver.solve(instance)
 
-def test_solve_save_load():
+def test_solve_save_load_state():
     instance = KnapsackInstance2(weights=[23., 26., 20., 18.],
                                  prices=[505., 352., 458., 220.],
                                  capacity=67.)
-    solver = LearningSolver()
+    components_before = {
+        "warm-start": WarmStartComponent(),
+        "branch-priority": BranchPriorityComponent(),
+    }
+    solver = LearningSolver(components=components_before)
     solver.solve(instance)
     solver.fit()
-    solver.save("/tmp/knapsack_train.bin")
-    prev_x_train_len = len(solver.x_train)
-    prev_y_train_len = len(solver.y_train)
+    solver.save_state("/tmp/knapsack_train.bin")
+    prev_x_train_len = len(solver.components["warm-start"].x_train)
+    prev_y_train_len = len(solver.components["warm-start"].y_train)
 
-    solver = LearningSolver()
-    solver.load("/tmp/knapsack_train.bin")
-    assert len(solver.x_train) == prev_x_train_len
-    assert len(solver.y_train) == prev_y_train_len
+    components_after = {
+        "warm-start": WarmStartComponent(),
+    }
+    solver = LearningSolver(components=components_after)
+    solver.load_state("/tmp/knapsack_train.bin")
+    assert len(solver.components.keys()) == 1
+    assert len(solver.components["warm-start"].x_train) == prev_x_train_len
+    assert len(solver.components["warm-start"].y_train) == prev_y_train_len
 
 def test_parallel_solve():
     instances = [KnapsackInstance2(weights=np.random.rand(5),
@@ -38,13 +48,18 @@ def test_parallel_solve():
                                   capacity=3.0)
                  for _ in range(10)]
     solver = LearningSolver()
-    solver.parallel_solve(instances, n_jobs=3)
-    assert len(solver.x_train[0]) == 10
-    assert len(solver.y_train[0]) == 10
+    results = solver.parallel_solve(instances, n_jobs=3)
+    assert len(results) == 10
+    assert len(solver.components["warm-start"].x_train[0]) == 10
+    assert len(solver.components["warm-start"].y_train[0]) == 10
 
 def test_solver_random_branch_priority():
     instance = KnapsackInstance2(weights=[23., 26., 20., 18.],
                                  prices=[505., 352., 458., 220.],
                                  capacity=67.)
-    solver = LearningSolver(branch_priority=[1, 2, 3, 4])
-    solver.solve(instance, tee=True)
+    components = {
+        "branch-priority": BranchPriorityComponent(initial_priority=np.array([1, 2, 3, 4])),
+    }
+    solver = LearningSolver(components=components)
+    solver.solve(instance)
+    solver.fit()
miplearn/warmstart.py
@@ -2,7 +2,11 @@
 # Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved.
 # Written by Alinson S. Xavier <axavier@anl.gov>
 
+from . import Component
+from .transformers import PerVariableTransformer
+
 from abc import ABC, abstractmethod
+from copy import deepcopy
 import numpy as np
 from sklearn.pipeline import make_pipeline
 from sklearn.linear_model import LogisticRegression
@@ -10,6 +14,7 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.model_selection import cross_val_score
 from sklearn.neighbors import KNeighborsClassifier
 
+
 class WarmStartPredictor(ABC):
     def __init__(self, thr_clip=[0.50, 0.50]):
         self.models = [None, None]
@@ -106,3 +111,78 @@ class KnnWarmStartPredictor(WarmStartPredictor):
         knn = KNeighborsClassifier(n_neighbors=self.k)
         knn.fit(x_train, y_train)
         return knn
+
+
+class WarmStartComponent(Component):
+    def __init__(self,
+                 predictor_prototype=LogisticWarmStartPredictor(),
+                 mode="exact",
+                ):
+        self.mode = mode
+        self.transformer = PerVariableTransformer()
+        self.x_train = {}
+        self.y_train = {}
+        self.predictors = {}
+        self.predictor_prototype = predictor_prototype
+
+    def before_solve(self, solver, instance, model):
+        var_split = self.transformer.split_variables(instance, model)
+        x_test = {}
+
+        # Collect training data (x_train) and build x_test
+        for category in var_split.keys():
+            var_index_pairs = var_split[category]
+            x = self.transformer.transform_instance(instance, var_index_pairs)
+            x_test[category] = x
+            if category not in self.x_train.keys():
+                self.x_train[category] = x
+            else:
+                assert x.shape[1] == self.x_train[category].shape[1]
+                self.x_train[category] = np.vstack([self.x_train[category], x])
+
+        # Predict solutions
+        for category in var_split.keys():
+            var_index_pairs = var_split[category]
+            if category in self.predictors.keys():
+                ws = self.predictors[category].predict(x_test[category])
+                assert ws.shape == (len(var_index_pairs), 2)
+                for i in range(len(var_index_pairs)):
+                    var, index = var_index_pairs[i]
+                    if self.mode == "heuristic":
+                        if ws[i,0] == 1:
+                            var[index].fix(0)
+                        elif ws[i,1] == 1:
+                            var[index].fix(1)
+                    else:
+                        if ws[i,0] == 1:
+                            var[index].value = 0
+                        elif ws[i,1] == 1:
+                            var[index].value = 1
+
+    def after_solve(self, solver, instance, model):
+        var_split = self.transformer.split_variables(instance, model)
+        for category in var_split.keys():
+            var_index_pairs = var_split[category]
+            y = self.transformer.transform_solution(var_index_pairs)
+            if category not in self.y_train.keys():
+                self.y_train[category] = y
+            else:
+                self.y_train[category] = np.vstack([self.y_train[category], y])
+
+    def fit(self, solver):
+        for category in self.x_train.keys():
+            x_train = self.x_train[category]
+            y_train = self.y_train[category]
+            self.predictors[category] = deepcopy(self.predictor_prototype)
+            self.predictors[category].fit(x_train, y_train)
+
+    def merge(self, other):
+        for c in other.x_train.keys():
+            if c not in self.x_train:
+                self.x_train[c] = other.x_train[c]
+                self.y_train[c] = other.y_train[c]
+            else:
+                self.x_train[c] = np.vstack([self.x_train[c], other.x_train[c]])
+                self.y_train[c] = np.vstack([self.y_train[c], other.y_train[c]])
+            if (c in other.predictors.keys()) and (c not in self.predictors.keys()):
+                self.predictors[c] = other.predictors[c]
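One way to read the prediction matrix used in before_solve above (an interpretation of the code, not a documented API): each predictor returns an (n, 2) binary matrix where column 0 votes for value 0 and column 1 votes for value 1; in "exact" mode the votes only seed a warm start, while in "heuristic" mode they permanently fix the variables.

    import numpy as np

    ws = np.array([[1, 0],   # variable 0: predicted value 0
                   [0, 1],   # variable 1: predicted value 1
                   [0, 0]])  # variable 2: no prediction, left free
    # mode="exact":     var[index].value = ...  (warm start; solver may override)
    # mode="heuristic": var[index].fix(...)     (shrinks the search space,
    #                                            sacrificing exactness)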