LearningSolver: add fix_variables option; simplify ws_predictor option

pull/1/head
Alinson S. Xavier 6 years ago
parent 077d5326bc
commit 3bb6e61e82

@ -6,6 +6,7 @@ from .transformers import PerVariableTransformer
from .warmstart import LogisticWarmStartPredictor
import pyomo.environ as pe
import numpy as np
from copy import deepcopy
class LearningSolver:
@ -17,13 +18,15 @@ class LearningSolver:
def __init__(self,
threads=4,
parent_solver=pe.SolverFactory('cbc'),
ws_predictor_factory=LogisticWarmStartPredictor):
ws_predictor=LogisticWarmStartPredictor(),
fix_variables=False):
self.parent_solver = parent_solver
self.parent_solver.options["threads"] = threads
self.ws_predictor_factory = ws_predictor_factory
self.fix_variables = fix_variables
self.x_train = {}
self.y_train = {}
self.ws_predictors = {}
self.ws_predictor_prototype = ws_predictor
def solve(self, instance, tee=False):
# Convert instance into concrete model
@ -52,6 +55,12 @@ class LearningSolver:
assert ws.shape == (len(var_index_pairs), 2)
for i in range(len(var_index_pairs)):
var, index = var_index_pairs[i]
if self.fix_variables:
if ws[i,0] == 1:
var[index].fix(0)
elif ws[i,1] == 1:
var[index].fix(1)
else:
if ws[i,0] == 1:
var[index].value = 0
elif ws[i,1] == 1:
@ -76,7 +85,7 @@ class LearningSolver:
for category in x_train_dict.keys():
x_train = x_train_dict[category]
y_train = y_train_dict[category]
self.ws_predictors[category] = self.ws_predictor_factory()
self.ws_predictors[category] = deepcopy(self.ws_predictor_prototype)
self.ws_predictors[category].fit(x_train, y_train)
def _solve(self, model, tee=False):

Loading…
Cancel
Save