From 3bb6e61e82714551515455152c1b0d7ace3c5dd1 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Wed, 22 Jan 2020 21:05:22 -0600 Subject: [PATCH] LearningSolver: add fix_variables option; simplify ws_predictor option --- miplearn/solvers.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/miplearn/solvers.py b/miplearn/solvers.py index d624adf..9214b85 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -6,6 +6,7 @@ from .transformers import PerVariableTransformer from .warmstart import LogisticWarmStartPredictor import pyomo.environ as pe import numpy as np +from copy import deepcopy class LearningSolver: @@ -17,13 +18,15 @@ class LearningSolver: def __init__(self, threads=4, parent_solver=pe.SolverFactory('cbc'), - ws_predictor_factory=LogisticWarmStartPredictor): + ws_predictor=LogisticWarmStartPredictor(), + fix_variables=False): self.parent_solver = parent_solver self.parent_solver.options["threads"] = threads - self.ws_predictor_factory = ws_predictor_factory + self.fix_variables = fix_variables self.x_train = {} self.y_train = {} self.ws_predictors = {} + self.ws_predictor_prototype = ws_predictor def solve(self, instance, tee=False): # Convert instance into concrete model @@ -52,10 +55,16 @@ class LearningSolver: assert ws.shape == (len(var_index_pairs), 2) for i in range(len(var_index_pairs)): var, index = var_index_pairs[i] - if ws[i,0] == 1: - var[index].value = 0 - elif ws[i,1] == 1: - var[index].value = 1 + if self.fix_variables: + if ws[i,0] == 1: + var[index].fix(0) + elif ws[i,1] == 1: + var[index].fix(1) + else: + if ws[i,0] == 1: + var[index].value = 0 + elif ws[i,1] == 1: + var[index].value = 1 # Solve MILP self._solve(model, tee=tee) @@ -76,7 +85,7 @@ class LearningSolver: for category in x_train_dict.keys(): x_train = x_train_dict[category] y_train = y_train_dict[category] - self.ws_predictors[category] = self.ws_predictor_factory() + self.ws_predictors[category] = deepcopy(self.ws_predictor_prototype) self.ws_predictors[category].fit(x_train, y_train) def _solve(self, model, tee=False):