LearningSolver: add fix_variables option; simplify ws_predictor option

pull/1/head
Alinson S. Xavier 6 years ago
parent 077d5326bc
commit 3bb6e61e82

@ -6,6 +6,7 @@ from .transformers import PerVariableTransformer
from .warmstart import LogisticWarmStartPredictor from .warmstart import LogisticWarmStartPredictor
import pyomo.environ as pe import pyomo.environ as pe
import numpy as np import numpy as np
from copy import deepcopy
class LearningSolver: class LearningSolver:
@ -17,13 +18,15 @@ class LearningSolver:
def __init__(self, def __init__(self,
threads=4, threads=4,
parent_solver=pe.SolverFactory('cbc'), parent_solver=pe.SolverFactory('cbc'),
ws_predictor_factory=LogisticWarmStartPredictor): ws_predictor=LogisticWarmStartPredictor(),
fix_variables=False):
self.parent_solver = parent_solver self.parent_solver = parent_solver
self.parent_solver.options["threads"] = threads self.parent_solver.options["threads"] = threads
self.ws_predictor_factory = ws_predictor_factory self.fix_variables = fix_variables
self.x_train = {} self.x_train = {}
self.y_train = {} self.y_train = {}
self.ws_predictors = {} self.ws_predictors = {}
self.ws_predictor_prototype = ws_predictor
def solve(self, instance, tee=False): def solve(self, instance, tee=False):
# Convert instance into concrete model # Convert instance into concrete model
@ -52,6 +55,12 @@ class LearningSolver:
assert ws.shape == (len(var_index_pairs), 2) assert ws.shape == (len(var_index_pairs), 2)
for i in range(len(var_index_pairs)): for i in range(len(var_index_pairs)):
var, index = var_index_pairs[i] var, index = var_index_pairs[i]
if self.fix_variables:
if ws[i,0] == 1:
var[index].fix(0)
elif ws[i,1] == 1:
var[index].fix(1)
else:
if ws[i,0] == 1: if ws[i,0] == 1:
var[index].value = 0 var[index].value = 0
elif ws[i,1] == 1: elif ws[i,1] == 1:
@ -76,7 +85,7 @@ class LearningSolver:
for category in x_train_dict.keys(): for category in x_train_dict.keys():
x_train = x_train_dict[category] x_train = x_train_dict[category]
y_train = y_train_dict[category] y_train = y_train_dict[category]
self.ws_predictors[category] = self.ws_predictor_factory() self.ws_predictors[category] = deepcopy(self.ws_predictor_prototype)
self.ws_predictors[category].fit(x_train, y_train) self.ws_predictors[category].fit(x_train, y_train)
def _solve(self, model, tee=False): def _solve(self, model, tee=False):

Loading…
Cancel
Save