mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 17:38:51 -06:00
Only set warm start flag if at least one variable is set
This commit is contained in:
@@ -2,21 +2,18 @@
|
|||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.metrics import roc_curve
|
||||||
|
from sklearn.model_selection import cross_val_score
|
||||||
|
from sklearn.neighbors import KNeighborsClassifier
|
||||||
|
from sklearn.pipeline import make_pipeline
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
from .component import Component
|
from .component import Component
|
||||||
from ..extractors import *
|
from ..extractors import *
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from copy import deepcopy
|
|
||||||
import numpy as np
|
|
||||||
from sklearn.pipeline import make_pipeline
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from sklearn.model_selection import cross_val_score
|
|
||||||
from sklearn.metrics import roc_curve
|
|
||||||
from sklearn.neighbors import KNeighborsClassifier
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
import pyomo.environ as pe
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -130,8 +127,6 @@ class PrimalSolutionComponent(Component):
|
|||||||
|
|
||||||
def before_solve(self, solver, instance, model):
|
def before_solve(self, solver, instance, model):
|
||||||
solution = self.predict(instance)
|
solution = self.predict(instance)
|
||||||
if solution is None:
|
|
||||||
return
|
|
||||||
if self.mode == "heuristic":
|
if self.mode == "heuristic":
|
||||||
solver.internal_solver.fix(solution)
|
solver.internal_solver.fix(solution)
|
||||||
else:
|
else:
|
||||||
@@ -182,12 +177,10 @@ class PrimalSolutionComponent(Component):
|
|||||||
(thresholds[k], fpr[k], tpr[k]))
|
(thresholds[k], fpr[k], tpr[k]))
|
||||||
self.thresholds[category, label] = thresholds[k]
|
self.thresholds[category, label] = thresholds[k]
|
||||||
|
|
||||||
|
|
||||||
def predict(self, instance):
|
def predict(self, instance):
|
||||||
x_test = VariableFeaturesExtractor().extract([instance])
|
x_test = VariableFeaturesExtractor().extract([instance])
|
||||||
solution = {}
|
solution = {}
|
||||||
var_split = Extractor.split_variables(instance)
|
var_split = Extractor.split_variables(instance)
|
||||||
all_none = True
|
|
||||||
for category in var_split.keys():
|
for category in var_split.keys():
|
||||||
for (i, (var, index)) in enumerate(var_split[category]):
|
for (i, (var, index)) in enumerate(var_split[category]):
|
||||||
if var not in solution.keys():
|
if var not in solution.keys():
|
||||||
@@ -201,8 +194,4 @@ class PrimalSolutionComponent(Component):
|
|||||||
(var, index, ws[i, 1], self.thresholds[category, label]))
|
(var, index, ws[i, 1], self.thresholds[category, label]))
|
||||||
if ws[i, 1] >= self.thresholds[category, label]:
|
if ws[i, 1] >= self.thresholds[category, label]:
|
||||||
solution[var][index] = label
|
solution[var][index] = label
|
||||||
if all_none:
|
|
||||||
all_none = False
|
|
||||||
if all_none:
|
|
||||||
return None
|
|
||||||
return solution
|
return solution
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ class InternalSolver:
|
|||||||
self.all_vars = None
|
self.all_vars = None
|
||||||
self.instance = None
|
self.instance = None
|
||||||
self.is_warm_start_available = False
|
self.is_warm_start_available = False
|
||||||
|
self.solver = None
|
||||||
self.model = None
|
self.model = None
|
||||||
self.sense = None
|
self.sense = None
|
||||||
self.solver = None
|
|
||||||
self.var_name_to_var = {}
|
self.var_name_to_var = {}
|
||||||
|
|
||||||
def solve_lp(self, tee=False):
|
def solve_lp(self, tee=False):
|
||||||
@@ -74,6 +74,7 @@ class InternalSolver:
|
|||||||
if var[index].fixed:
|
if var[index].fixed:
|
||||||
continue
|
continue
|
||||||
var[index].value = None
|
var[index].value = None
|
||||||
|
self.is_warm_start_available = False
|
||||||
|
|
||||||
def get_solution(self):
|
def get_solution(self):
|
||||||
solution = {}
|
solution = {}
|
||||||
@@ -84,7 +85,6 @@ class InternalSolver:
|
|||||||
return solution
|
return solution
|
||||||
|
|
||||||
def set_warm_start(self, solution):
|
def set_warm_start(self, solution):
|
||||||
self.is_warm_start_available = True
|
|
||||||
self.clear_values()
|
self.clear_values()
|
||||||
count_total, count_fixed = 0, 0
|
count_total, count_fixed = 0, 0
|
||||||
for var_name in solution:
|
for var_name in solution:
|
||||||
@@ -94,6 +94,8 @@ class InternalSolver:
|
|||||||
var[index].value = solution[var_name][index]
|
var[index].value = solution[var_name][index]
|
||||||
if solution[var_name][index] is not None:
|
if solution[var_name][index] is not None:
|
||||||
count_fixed += 1
|
count_fixed += 1
|
||||||
|
if count_fixed > 0:
|
||||||
|
self.is_warm_start_available = True
|
||||||
logger.info("Setting start values for %d variables (out of %d)" %
|
logger.info("Setting start values for %d variables (out of %d)" %
|
||||||
(count_fixed, count_total))
|
(count_fixed, count_total))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user