|
|
|
@ -4,6 +4,7 @@
|
|
|
|
|
|
|
|
|
|
from . import Component
|
|
|
|
|
from .transformers import PerVariableTransformer
|
|
|
|
|
from .extractors import *
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from copy import deepcopy
|
|
|
|
@ -131,24 +132,21 @@ class WarmStartComponent(Component):
|
|
|
|
|
self.predictor_prototype = predictor_prototype
|
|
|
|
|
|
|
|
|
|
def before_solve(self, solver, instance, model):
|
|
|
|
|
var_split = self.transformer.split_variables(instance, model)
|
|
|
|
|
x_test = {}
|
|
|
|
|
# Build x_test
|
|
|
|
|
x_test = CombinedExtractor([UserFeaturesExtractor(),
|
|
|
|
|
SolutionExtractor(),
|
|
|
|
|
]).extract([instance], [model])
|
|
|
|
|
|
|
|
|
|
# Collect training data (x_train) and build x_test
|
|
|
|
|
for category in var_split.keys():
|
|
|
|
|
var_index_pairs = var_split[category]
|
|
|
|
|
x = self.transformer.transform_instance(instance, var_index_pairs)
|
|
|
|
|
x_test[category] = x
|
|
|
|
|
if category not in self.x_train.keys():
|
|
|
|
|
self.x_train[category] = x
|
|
|
|
|
else:
|
|
|
|
|
assert x.shape[1] == self.x_train[category].shape[1]
|
|
|
|
|
self.x_train[category] = np.vstack([self.x_train[category], x])
|
|
|
|
|
# Update self.x_train
|
|
|
|
|
self.x_train = Extractor.merge([self.x_train, x_test],
|
|
|
|
|
vertical=True)
|
|
|
|
|
|
|
|
|
|
# Predict solutions
|
|
|
|
|
var_split = Extractor.split_variables(instance, model)
|
|
|
|
|
for category in var_split.keys():
|
|
|
|
|
var_index_pairs = var_split[category]
|
|
|
|
|
if category in self.predictors.keys():
|
|
|
|
|
if category not in self.predictors.keys():
|
|
|
|
|
continue
|
|
|
|
|
ws = self.predictors[category].predict(x_test[category])
|
|
|
|
|
assert ws.shape == (len(var_index_pairs), 2)
|
|
|
|
|
for i in range(len(var_index_pairs)):
|
|
|
|
@ -169,14 +167,8 @@ class WarmStartComponent(Component):
|
|
|
|
|
var[index].value = 1
|
|
|
|
|
|
|
|
|
|
def after_solve(self, solver, instance, model):
|
|
|
|
|
var_split = self.transformer.split_variables(instance, model)
|
|
|
|
|
for category in var_split.keys():
|
|
|
|
|
var_index_pairs = var_split[category]
|
|
|
|
|
y = self.transformer.transform_solution(var_index_pairs)
|
|
|
|
|
if category not in self.y_train.keys():
|
|
|
|
|
self.y_train[category] = y
|
|
|
|
|
else:
|
|
|
|
|
self.y_train[category] = np.vstack([self.y_train[category], y])
|
|
|
|
|
y_test = SolutionExtractor().extract([instance], [model])
|
|
|
|
|
self.y_train = Extractor.merge([self.y_train, y_test], vertical=True)
|
|
|
|
|
|
|
|
|
|
def fit(self, solver, n_jobs=1):
|
|
|
|
|
for category in tqdm(self.x_train.keys(), desc="Warm start"):
|
|
|
|
|