|
|
|
@ -130,6 +130,8 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
|
|
|
|
|
def before_solve(self, solver, instance, model):
|
|
|
|
|
solution = self.predict(instance)
|
|
|
|
|
if solution is None:
|
|
|
|
|
return
|
|
|
|
|
if self.mode == "heuristic":
|
|
|
|
|
solver.internal_solver.fix(solution)
|
|
|
|
|
else:
|
|
|
|
@ -185,6 +187,7 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
x_test = VariableFeaturesExtractor().extract([instance])
|
|
|
|
|
solution = {}
|
|
|
|
|
var_split = Extractor.split_variables(instance)
|
|
|
|
|
all_none = True
|
|
|
|
|
for category in var_split.keys():
|
|
|
|
|
for (i, (var, index)) in enumerate(var_split[category]):
|
|
|
|
|
if var not in solution.keys():
|
|
|
|
@ -198,4 +201,8 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
(var, index, ws[i, 1], self.thresholds[category, label]))
|
|
|
|
|
if ws[i, 1] >= self.thresholds[category, label]:
|
|
|
|
|
solution[var][index] = label
|
|
|
|
|
if all_none:
|
|
|
|
|
all_none = False
|
|
|
|
|
if all_none:
|
|
|
|
|
return None
|
|
|
|
|
return solution
|
|
|
|
|