diff --git a/src/python/miplearn/components/primal.py b/src/python/miplearn/components/primal.py index 7bc8d47..a9401a2 100644 --- a/src/python/miplearn/components/primal.py +++ b/src/python/miplearn/components/primal.py @@ -2,21 +2,18 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from .component import Component -from ..extractors import * - -from abc import ABC, abstractmethod from copy import deepcopy -import numpy as np -from sklearn.pipeline import make_pipeline + from sklearn.linear_model import LogisticRegression -from sklearn.preprocessing import StandardScaler -from sklearn.model_selection import cross_val_score from sklearn.metrics import roc_curve +from sklearn.model_selection import cross_val_score from sklearn.neighbors import KNeighborsClassifier -from tqdm.auto import tqdm -import pyomo.environ as pe -import logging +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from .component import Component +from ..extractors import * + logger = logging.getLogger(__name__) @@ -130,8 +127,6 @@ class PrimalSolutionComponent(Component): def before_solve(self, solver, instance, model): solution = self.predict(instance) - if solution is None: - return if self.mode == "heuristic": solver.internal_solver.fix(solution) else: @@ -182,12 +177,10 @@ class PrimalSolutionComponent(Component): (thresholds[k], fpr[k], tpr[k])) self.thresholds[category, label] = thresholds[k] - def predict(self, instance): x_test = VariableFeaturesExtractor().extract([instance]) solution = {} var_split = Extractor.split_variables(instance) - all_none = True for category in var_split.keys(): for (i, (var, index)) in enumerate(var_split[category]): if var not in solution.keys(): @@ -201,8 +194,4 @@ class PrimalSolutionComponent(Component): (var, index, ws[i, 1], self.thresholds[category, label])) if ws[i, 1] >= self.thresholds[category, label]: solution[var][index] = label - if all_none: - all_none = False - if all_none: - return None return solution diff --git a/src/python/miplearn/solvers.py b/src/python/miplearn/solvers.py index bee0734..6f0e406 100644 --- a/src/python/miplearn/solvers.py +++ b/src/python/miplearn/solvers.py @@ -37,9 +37,9 @@ class InternalSolver: self.all_vars = None self.instance = None self.is_warm_start_available = False + self.solver = None self.model = None self.sense = None - self.solver = None self.var_name_to_var = {} def solve_lp(self, tee=False): @@ -74,6 +74,7 @@ class InternalSolver: if var[index].fixed: continue var[index].value = None + self.is_warm_start_available = False def get_solution(self): solution = {} @@ -84,7 +85,6 @@ class InternalSolver: return solution def set_warm_start(self, solution): - self.is_warm_start_available = True self.clear_values() count_total, count_fixed = 0, 0 for var_name in solution: @@ -94,6 +94,8 @@ class InternalSolver: var[index].value = solution[var_name][index] if solution[var_name][index] is not None: count_fixed += 1 + if count_fixed > 0: + self.is_warm_start_available = True logger.info("Setting start values for %d variables (out of %d)" % (count_fixed, count_total))