mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: add fix_variables option; simplify ws_predictor option
This commit is contained in:
@@ -6,6 +6,7 @@ from .transformers import PerVariableTransformer
|
||||
from .warmstart import LogisticWarmStartPredictor
|
||||
import pyomo.environ as pe
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class LearningSolver:
|
||||
@@ -17,13 +18,15 @@ class LearningSolver:
|
||||
def __init__(self,
|
||||
threads=4,
|
||||
parent_solver=pe.SolverFactory('cbc'),
|
||||
ws_predictor_factory=LogisticWarmStartPredictor):
|
||||
ws_predictor=LogisticWarmStartPredictor(),
|
||||
fix_variables=False):
|
||||
self.parent_solver = parent_solver
|
||||
self.parent_solver.options["threads"] = threads
|
||||
self.ws_predictor_factory = ws_predictor_factory
|
||||
self.fix_variables = fix_variables
|
||||
self.x_train = {}
|
||||
self.y_train = {}
|
||||
self.ws_predictors = {}
|
||||
self.ws_predictor_prototype = ws_predictor
|
||||
|
||||
def solve(self, instance, tee=False):
|
||||
# Convert instance into concrete model
|
||||
@@ -52,10 +55,16 @@ class LearningSolver:
|
||||
assert ws.shape == (len(var_index_pairs), 2)
|
||||
for i in range(len(var_index_pairs)):
|
||||
var, index = var_index_pairs[i]
|
||||
if ws[i,0] == 1:
|
||||
var[index].value = 0
|
||||
elif ws[i,1] == 1:
|
||||
var[index].value = 1
|
||||
if self.fix_variables:
|
||||
if ws[i,0] == 1:
|
||||
var[index].fix(0)
|
||||
elif ws[i,1] == 1:
|
||||
var[index].fix(1)
|
||||
else:
|
||||
if ws[i,0] == 1:
|
||||
var[index].value = 0
|
||||
elif ws[i,1] == 1:
|
||||
var[index].value = 1
|
||||
|
||||
# Solve MILP
|
||||
self._solve(model, tee=tee)
|
||||
@@ -76,7 +85,7 @@ class LearningSolver:
|
||||
for category in x_train_dict.keys():
|
||||
x_train = x_train_dict[category]
|
||||
y_train = y_train_dict[category]
|
||||
self.ws_predictors[category] = self.ws_predictor_factory()
|
||||
self.ws_predictors[category] = deepcopy(self.ws_predictor_prototype)
|
||||
self.ws_predictors[category].fit(x_train, y_train)
|
||||
|
||||
def _solve(self, model, tee=False):
|
||||
|
||||
Reference in New Issue
Block a user