Only set warm start flag if at least one variable is set

pull/3/head
Alinson S. Xavier 6 years ago
parent c8152aab6c
commit 5bb109cfad

@ -2,21 +2,18 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from .component import Component
from ..extractors import *
from abc import ABC, abstractmethod
from copy import deepcopy
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
from sklearn.metrics import roc_curve
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from tqdm.auto import tqdm
import pyomo.environ as pe
import logging
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from .component import Component
from ..extractors import *
logger = logging.getLogger(__name__)
@ -130,8 +127,6 @@ class PrimalSolutionComponent(Component):
def before_solve(self, solver, instance, model):
solution = self.predict(instance)
if solution is None:
return
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
@ -182,12 +177,10 @@ class PrimalSolutionComponent(Component):
(thresholds[k], fpr[k], tpr[k]))
self.thresholds[category, label] = thresholds[k]
def predict(self, instance):
x_test = VariableFeaturesExtractor().extract([instance])
solution = {}
var_split = Extractor.split_variables(instance)
all_none = True
for category in var_split.keys():
for (i, (var, index)) in enumerate(var_split[category]):
if var not in solution.keys():
@ -201,8 +194,4 @@ class PrimalSolutionComponent(Component):
(var, index, ws[i, 1], self.thresholds[category, label]))
if ws[i, 1] >= self.thresholds[category, label]:
solution[var][index] = label
if all_none:
all_none = False
if all_none:
return None
return solution

@ -37,9 +37,9 @@ class InternalSolver:
self.all_vars = None
self.instance = None
self.is_warm_start_available = False
self.solver = None
self.model = None
self.sense = None
self.solver = None
self.var_name_to_var = {}
def solve_lp(self, tee=False):
@ -74,6 +74,7 @@ class InternalSolver:
if var[index].fixed:
continue
var[index].value = None
self.is_warm_start_available = False
def get_solution(self):
solution = {}
@ -84,7 +85,6 @@ class InternalSolver:
return solution
def set_warm_start(self, solution):
self.is_warm_start_available = True
self.clear_values()
count_total, count_fixed = 0, 0
for var_name in solution:
@ -94,6 +94,8 @@ class InternalSolver:
var[index].value = solution[var_name][index]
if solution[var_name][index] is not None:
count_fixed += 1
if count_fixed > 0:
self.is_warm_start_available = True
logger.info("Setting start values for %d variables (out of %d)" %
(count_fixed, count_total))

Loading…
Cancel
Save