Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 01:18:52 -06:00)
KNN: Make distinction between k and min_samples; improve logging
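In short: k (the number of KNN neighbors) previously doubled as the minimum training-set size; this commit separates the two with a new min_samples parameter and adds debug/info logging. A minimal usage sketch, assuming the constructor signature introduced in this diff (import path omitted, since the file's path is not shown on this page):

    # Sketch only -- mirrors the KnnWarmStartPredictor signature from this diff.
    predictor = KnnWarmStartPredictor(
        k=50,            # neighbors for the underlying KNeighborsClassifier
        min_samples=1,   # below this many training samples, predict nothing
        thr_clip=[0.80, 0.80],
        thr_fix=[1.0, 1.0],
    )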
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-PYTEST_ARGS := -W ignore::DeprecationWarning -vv -x
+PYTEST_ARGS := -W ignore::DeprecationWarning -vv -x --log-level=DEBUG
 
 all: docs test
 
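Note: --log-level=DEBUG is pytest's standard live-log option (from its built-in logging plugin); it makes the logger.debug() calls added below visible in test output. The resulting invocation is, in effect:

    pytest -W ignore::DeprecationWarning -vv -x --log-level=DEBUG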
@@ -15,7 +15,8 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.model_selection import cross_val_score
 from sklearn.neighbors import KNeighborsClassifier
 from tqdm.auto import tqdm
+import logging
 
+logger = logging.getLogger(__name__)
 
 class WarmStartPredictor(ABC):
     def __init__(self, thr_clip=[0.50, 0.50]):
@@ -91,30 +92,38 @@ class LogisticWarmStartPredictor(WarmStartPredictor):
 
 
 class KnnWarmStartPredictor(WarmStartPredictor):
-    def __init__(self, k=50,
-                 thr_clip=[0.90, 0.90],
-                 thr_fix=[0.99, 0.99],
+    def __init__(self,
+                 k=50,
+                 min_samples=1,
+                 thr_clip=[0.80, 0.80],
+                 thr_fix=[1.0, 1.0],
                 ):
         super().__init__(thr_clip=thr_clip)
         self.k = k
         self.thr_fix = thr_fix
+        self.min_samples = min_samples
 
     def _fit(self, x_train, y_train, label):
         y_train_avg = np.average(y_train)
 
         # If number of training samples is too small, don't predict anything.
-        if x_train.shape[0] < self.k:
+        if x_train.shape[0] < self.min_samples:
+            logger.debug("Too few samples; return 0")
             return 0
 
         # If vast majority of observations are true, always return true.
-        if y_train_avg > self.thr_fix[label]:
+        if y_train_avg >= self.thr_fix[label]:
+            logger.debug("Consensus reached; return 1")
             return 1
 
         # If vast majority of observations are false, always return false.
-        if y_train_avg < (1 - self.thr_fix[label]):
+        if y_train_avg <= (1 - self.thr_fix[label]):
+            logger.debug("Consensus reached; return 0")
             return 0
 
-        knn = KNeighborsClassifier(n_neighbors=self.k)
+        logger.debug("Training classifier...")
+        k = min(self.k, x_train.shape[0])
+        knn = KNeighborsClassifier(n_neighbors=k)
         knn.fit(x_train, y_train)
         return knn
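A note on the clamp (an aside, not part of the commit): k = min(self.k, x_train.shape[0]) matters because scikit-learn's KNeighborsClassifier fails at prediction time when n_neighbors exceeds the number of fitted samples, and that can now happen since min_samples may be much smaller than k. A minimal sketch:

    import numpy as np
    from sklearn.neighbors import KNeighborsClassifier

    x_train = np.random.rand(10, 3)           # only 10 samples available
    y_train = np.random.randint(0, 2, 10)
    k = min(50, x_train.shape[0])             # clamp k to the sample count
    knn = KNeighborsClassifier(n_neighbors=k).fit(x_train, y_train)
    knn.predict(x_train)                      # would raise with n_neighbors=50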
@@ -143,6 +152,7 @@ class WarmStartComponent(Component):
                                       vertical=True)
 
         # Predict solutions
+        count_total, count_fixed = 0, 0
         var_split = Extractor.split_variables(instance, model)
         for category in var_split.keys():
             var_index_pairs = var_split[category]
@@ -152,22 +162,29 @@ class WarmStartComponent(Component):
             assert ws.shape == (len(var_index_pairs), 2)
             for i in range(len(var_index_pairs)):
                 var, index = var_index_pairs[i]
+                count_total += 1
                 if self.mode == "heuristic":
-                    if ws[i,0] == 1:
+                    if ws[i,0] > 0.5:
                         var[index].fix(0)
+                        count_fixed += 1
                         if solver.is_persistent:
                             solver.internal_solver.update_var(var[index])
-                    elif ws[i,1] == 1:
+                    elif ws[i,1] > 0.5:
                         var[index].fix(1)
+                        count_fixed += 1
                         if solver.is_persistent:
                             solver.internal_solver.update_var(var[index])
                 else:
-                    if ws[i,0] == 1:
+                    var[index].value = None
+                    if ws[i,0] > 0.5:
+                        count_fixed += 1
                         var[index].value = 0
                         self.is_warm_start_available = True
-                    elif ws[i,1] == 1:
+                    elif ws[i,1] > 0.5:
+                        count_fixed += 1
                         var[index].value = 1
                         self.is_warm_start_available = True
+        logger.info("Setting values for %d variables (out of %d)" % (count_fixed, count_total))
 
 
     def after_solve(self, solver, instance, model):
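The switch from == 1 to > 0.5 in the warm-start checks is also a robustness fix: if the prediction matrix ws holds floats (which the old exact comparison suggests it may), equality with 1 is fragile. A small illustration, assuming ws is a float NumPy array:

    import numpy as np

    ws = np.array([[0.9999999, 0.0]])   # hypothetical prediction row
    print(ws[0, 0] == 1)                # False: exact equality misses it
    print(ws[0, 0] > 0.5)               # True: thresholded test catches it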