Redesign component.evaluate

2021-04-02 08:09:35 -05:00
parent 0c687692f7
commit 0bce2051a8
9 changed files with 221 additions and 178 deletions

@@ -4,20 +4,16 @@
 import logging
 from typing import (
     Union,
     Dict,
-    Callable,
     List,
-    Hashable,
     Optional,
     Any,
     TYPE_CHECKING,
     Tuple,
-    cast,
 )

 import numpy as np
-from tqdm.auto import tqdm

 from miplearn.classifiers import Classifier
 from miplearn.classifiers.adaptive import AdaptiveClassifier
@@ -72,53 +68,39 @@ class PrimalSolutionComponent(Component):
         features: Features,
         training_data: TrainingSample,
     ) -> None:
-        if len(self.thresholds) > 0:
-            logger.info("Predicting MIP solution...")
-            solution = self.predict(
-                instance.features,
-                instance.training_data[-1],
-            )
-            # Update statistics
-            stats["Primal: Free"] = 0
-            stats["Primal: Zero"] = 0
-            stats["Primal: One"] = 0
-            for (var, var_dict) in solution.items():
-                for (idx, value) in var_dict.items():
-                    if value is None:
-                        stats["Primal: Free"] += 1
-                    else:
-                        if value < 0.5:
-                            stats["Primal: Zero"] += 1
-                        else:
-                            stats["Primal: One"] += 1
-            logger.info(
-                f"Predicted: free: {stats['Primal: Free']}, "
-                f"zero: {stats['Primal: Zero']}, "
-                f"one: {stats['Primal: One']}"
-            )
-            # Provide solution to the solver
-            assert solver.internal_solver is not None
-            if self.mode == "heuristic":
-                solver.internal_solver.fix(solution)
-            else:
-                solver.internal_solver.set_warm_start(solution)
+        # Do nothing if models are not trained
+        if len(self.classifiers) == 0:
+            return
+
+        # Predict solution and provide it to the solver
+        logger.info("Predicting MIP solution...")
+        solution = self.sample_predict(features, training_data)
+        assert solver.internal_solver is not None
+        if self.mode == "heuristic":
+            solver.internal_solver.fix(solution)
+        else:
+            solver.internal_solver.set_warm_start(solution)
+
+        # Update statistics
+        stats["Primal: Free"] = 0
+        stats["Primal: Zero"] = 0
+        stats["Primal: One"] = 0
+        for (var, var_dict) in solution.items():
+            for (idx, value) in var_dict.items():
+                if value is None:
+                    stats["Primal: Free"] += 1
+                else:
+                    if value < 0.5:
+                        stats["Primal: Zero"] += 1
+                    else:
+                        stats["Primal: One"] += 1
+        logger.info(
+            f"Predicted: free: {stats['Primal: Free']}, "
+            f"zero: {stats['Primal: Zero']}, "
+            f"one: {stats['Primal: One']}"
+        )

-    def fit_xy(
-        self,
-        x: Dict[str, np.ndarray],
-        y: Dict[str, np.ndarray],
-    ) -> None:
-        for category in x.keys():
-            clf = self.classifier_prototype.clone()
-            thr = self.threshold_prototype.clone()
-            clf.fit(x[category], y[category])
-            thr.fit(clf, x[category], y[category])
-            self.classifiers[category] = clf
-            self.thresholds[category] = thr

-    def predict(
+    def sample_predict(
         self,
         features: Features,
         sample: TrainingSample,
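
In the redesigned before_solve above, every predicted variable value falls into one of three buckets: free (no prediction), zero, or one. A minimal standalone sketch of that counting step, using a toy solution dictionary in the same nested variable -> index -> value format instead of real miplearn objects:

    from typing import Dict, Hashable, Optional

    # Toy predicted solution: variable name -> index -> value (None = leave free).
    solution: Dict[str, Dict[Hashable, Optional[float]]] = {
        "x": {0: 0.0, 1: 1.0, 2: None},
        "y": {0: 1.0},
    }

    stats = {"Primal: Free": 0, "Primal: Zero": 0, "Primal: One": 0}
    for (var, var_dict) in solution.items():
        for (idx, value) in var_dict.items():
            if value is None:
                stats["Primal: Free"] += 1
            elif value < 0.5:
                stats["Primal: Zero"] += 1
            else:
                stats["Primal: One"] += 1

    print(stats)  # {'Primal: Free': 1, 'Primal: Zero': 1, 'Primal: One': 2}
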
@@ -131,7 +113,7 @@ class PrimalSolutionComponent(Component):
                 solution[var_name][idx] = None

         # Compute y_pred
-        x, _ = self.xy(features, sample)
+        x, _ = self.sample_xy(features, sample)
         y_pred = {}
         for category in x.keys():
             assert category in self.classifiers, (
@@ -162,55 +144,8 @@ class PrimalSolutionComponent(Component):
         return solution

-    def evaluate(self, instances):
-        ev = {"Fix zero": {}, "Fix one": {}}
-        for instance_idx in tqdm(
-            range(len(instances)),
-            desc="Evaluate (primal)",
-        ):
-            instance = instances[instance_idx]
-            solution_actual = instance.training_data[0]["Solution"]
-            solution_pred = self.predict(instance, instance.training_data[0])
-            vars_all, vars_one, vars_zero = set(), set(), set()
-            pred_one_positive, pred_zero_positive = set(), set()
-            for (varname, var_dict) in solution_actual.items():
-                if varname not in solution_pred.keys():
-                    continue
-                for (idx, value) in var_dict.items():
-                    vars_all.add((varname, idx))
-                    if value > 0.5:
-                        vars_one.add((varname, idx))
-                    else:
-                        vars_zero.add((varname, idx))
-                    if solution_pred[varname][idx] is not None:
-                        if solution_pred[varname][idx] > 0.5:
-                            pred_one_positive.add((varname, idx))
-                        else:
-                            pred_zero_positive.add((varname, idx))
-            pred_one_negative = vars_all - pred_one_positive
-            pred_zero_negative = vars_all - pred_zero_positive
-            tp_zero = len(pred_zero_positive & vars_zero)
-            fp_zero = len(pred_zero_positive & vars_one)
-            tn_zero = len(pred_zero_negative & vars_one)
-            fn_zero = len(pred_zero_negative & vars_zero)
-            tp_one = len(pred_one_positive & vars_one)
-            fp_one = len(pred_one_positive & vars_zero)
-            tn_one = len(pred_one_negative & vars_zero)
-            fn_one = len(pred_one_negative & vars_one)
-            ev["Fix zero"][instance_idx] = classifier_evaluation_dict(
-                tp_zero, tn_zero, fp_zero, fn_zero
-            )
-            ev["Fix one"][instance_idx] = classifier_evaluation_dict(
-                tp_one, tn_one, fp_one, fn_one
-            )
-        return ev

     @staticmethod
-    def xy(
+    def sample_xy(
         features: Features,
         sample: TrainingSample,
     ) -> Tuple[Dict, Dict]:
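
The removed evaluate owned the loop over instances (with its tqdm progress bar) as well as the per-instance confusion counts. After this redesign, the per-instance work lives in sample_evaluate, added at the end of this file, and the outer loop presumably moves to the shared Component.evaluate, which is not shown in this diff. A hypothetical driver with that shape, reusing the instance.features and instance.training_data attributes already referenced above (any name outside this diff is an assumption):

    # Hypothetical aggregation loop, not part of this commit: evaluate one
    # trained component over a list of solved instances by delegating the
    # per-sample work to sample_evaluate.
    def evaluate_all(component, instances):
        ev = {}
        for (instance_idx, instance) in enumerate(instances):
            ev[instance_idx] = component.sample_evaluate(
                instance.features,
                instance.training_data[0],
            )
        return ev
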
@@ -246,3 +181,59 @@ class PrimalSolutionComponent(Component):
                     )
                     y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
         return x, y
+
+    def sample_evaluate(
+        self,
+        features: Features,
+        sample: TrainingSample,
+    ) -> Dict:
+        solution_actual = sample["Solution"]
+        assert solution_actual is not None
+        solution_pred = self.sample_predict(features, sample)
+        vars_all, vars_one, vars_zero = set(), set(), set()
+        pred_one_positive, pred_zero_positive = set(), set()
+        for (varname, var_dict) in solution_actual.items():
+            if varname not in solution_pred.keys():
+                continue
+            for (idx, value_actual) in var_dict.items():
+                assert value_actual is not None
+                vars_all.add((varname, idx))
+                if value_actual > 0.5:
+                    vars_one.add((varname, idx))
+                else:
+                    vars_zero.add((varname, idx))
+                value_pred = solution_pred[varname][idx]
+                if value_pred is not None:
+                    if value_pred > 0.5:
+                        pred_one_positive.add((varname, idx))
+                    else:
+                        pred_zero_positive.add((varname, idx))
+        pred_one_negative = vars_all - pred_one_positive
+        pred_zero_negative = vars_all - pred_zero_positive
+        return {
+            0: classifier_evaluation_dict(
+                tp=len(pred_zero_positive & vars_zero),
+                tn=len(pred_zero_negative & vars_one),
+                fp=len(pred_zero_positive & vars_one),
+                fn=len(pred_zero_negative & vars_zero),
+            ),
+            1: classifier_evaluation_dict(
+                tp=len(pred_one_positive & vars_one),
+                tn=len(pred_one_negative & vars_zero),
+                fp=len(pred_one_positive & vars_zero),
+                fn=len(pred_one_negative & vars_one),
+            ),
+        }
+
+    def fit_xy(
+        self,
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
+    ) -> None:
+        for category in x.keys():
+            clf = self.classifier_prototype.clone()
+            thr = self.threshold_prototype.clone()
+            clf.fit(x[category], y[category])
+            thr.fit(clf, x[category], y[category])
+            self.classifiers[category] = clf
+            self.thresholds[category] = thr
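
To make the confusion counts in the new sample_evaluate concrete: key 0 of the returned dictionary scores the "fix zero" predictions (an actual value of 0 is the positive class) and key 1 scores the "fix one" predictions. A standalone sketch of the same set arithmetic on a toy actual/predicted pair, using a plain dict of counts instead of classifier_evaluation_dict so it runs without miplearn:

    from typing import Dict, Hashable, Optional

    # Toy data in the nested format used by sample_evaluate: the actual optimal
    # solution and the (partial) predicted solution.
    solution_actual = {"x": {0: 0.0, 1: 1.0, 2: 1.0}}
    solution_pred: Dict[str, Dict[Hashable, Optional[float]]] = {
        "x": {0: 0.0, 1: 1.0, 2: None},
    }

    vars_all, vars_one, vars_zero = set(), set(), set()
    pred_one_positive, pred_zero_positive = set(), set()
    for (varname, var_dict) in solution_actual.items():
        if varname not in solution_pred:
            continue
        for (idx, value_actual) in var_dict.items():
            vars_all.add((varname, idx))
            (vars_one if value_actual > 0.5 else vars_zero).add((varname, idx))
            value_pred = solution_pred[varname][idx]
            if value_pred is not None:
                if value_pred > 0.5:
                    pred_one_positive.add((varname, idx))
                else:
                    pred_zero_positive.add((varname, idx))
    pred_one_negative = vars_all - pred_one_positive
    pred_zero_negative = vars_all - pred_zero_positive

    # "Fix one": predicting value 1 is the positive class. The variable x[2]
    # was left free by the prediction, so it counts as a false negative here.
    fix_one = {
        "tp": len(pred_one_positive & vars_one),
        "tn": len(pred_one_negative & vars_zero),
        "fp": len(pred_one_positive & vars_zero),
        "fn": len(pred_one_negative & vars_one),
    }
    print(fix_one)  # {'tp': 1, 'tn': 1, 'fp': 0, 'fn': 1}

Scoring the two directions separately is useful because the zero and one classes are typically very unbalanced in MIP solutions, so a single accuracy figure would hide how well the rarer class is predicted.
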