mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Reformat source code with Black; add pre-commit hooks and CI checks
This commit is contained in:
@@ -19,10 +19,12 @@ class PrimalSolutionComponent(Component):
|
||||
A component that predicts primal solutions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
classifier=AdaptiveClassifier(),
|
||||
mode="exact",
|
||||
threshold=MinPrecisionThreshold(0.98)):
|
||||
def __init__(
|
||||
self,
|
||||
classifier=AdaptiveClassifier(),
|
||||
mode="exact",
|
||||
threshold=MinPrecisionThreshold(0.98),
|
||||
):
|
||||
self.mode = mode
|
||||
self.classifiers = {}
|
||||
self.thresholds = {}
|
||||
@@ -51,9 +53,10 @@ class PrimalSolutionComponent(Component):
|
||||
features = VariableFeaturesExtractor().extract(training_instances)
|
||||
solutions = SolutionExtractor().extract(training_instances)
|
||||
|
||||
for category in tqdm(features.keys(),
|
||||
desc="Fit (primal)",
|
||||
):
|
||||
for category in tqdm(
|
||||
features.keys(),
|
||||
desc="Fit (primal)",
|
||||
):
|
||||
x_train = features[category]
|
||||
for label in [0, 1]:
|
||||
y_train = solutions[category][:, label].astype(int)
|
||||
@@ -74,9 +77,15 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
# Find threshold (dynamic or static)
|
||||
if isinstance(self.threshold_prototype, DynamicThreshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.find(clf, x_train, y_train)
|
||||
self.thresholds[category, label] = self.threshold_prototype.find(
|
||||
clf,
|
||||
x_train,
|
||||
y_train,
|
||||
)
|
||||
else:
|
||||
self.thresholds[category, label] = deepcopy(self.threshold_prototype)
|
||||
self.thresholds[category, label] = deepcopy(
|
||||
self.threshold_prototype
|
||||
)
|
||||
|
||||
self.classifiers[category, label] = clf
|
||||
|
||||
@@ -98,18 +107,21 @@ class PrimalSolutionComponent(Component):
|
||||
ws = np.array([[1 - clf, clf] for _ in range(n)])
|
||||
else:
|
||||
ws = clf.predict_proba(x_test[category])
|
||||
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (n, ws.shape)
|
||||
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (
|
||||
n,
|
||||
ws.shape,
|
||||
)
|
||||
for (i, (var, index)) in enumerate(var_split[category]):
|
||||
if ws[i, 1] >= self.thresholds[category, label]:
|
||||
solution[var][index] = label
|
||||
return solution
|
||||
|
||||
def evaluate(self, instances):
|
||||
ev = {"Fix zero": {},
|
||||
"Fix one": {}}
|
||||
for instance_idx in tqdm(range(len(instances)),
|
||||
desc="Evaluate (primal)",
|
||||
):
|
||||
ev = {"Fix zero": {}, "Fix one": {}}
|
||||
for instance_idx in tqdm(
|
||||
range(len(instances)),
|
||||
desc="Evaluate (primal)",
|
||||
):
|
||||
instance = instances[instance_idx]
|
||||
solution_actual = instance.solution
|
||||
solution_pred = self.predict(instance)
|
||||
@@ -143,6 +155,10 @@ class PrimalSolutionComponent(Component):
|
||||
tn_one = len(pred_one_negative & vars_zero)
|
||||
fn_one = len(pred_one_negative & vars_one)
|
||||
|
||||
ev["Fix zero"][instance_idx] = classifier_evaluation_dict(tp_zero, tn_zero, fp_zero, fn_zero)
|
||||
ev["Fix one"][instance_idx] = classifier_evaluation_dict(tp_one, tn_one, fp_one, fn_one)
|
||||
ev["Fix zero"][instance_idx] = classifier_evaluation_dict(
|
||||
tp_zero, tn_zero, fp_zero, fn_zero
|
||||
)
|
||||
ev["Fix one"][instance_idx] = classifier_evaluation_dict(
|
||||
tp_one, tn_one, fp_one, fn_one
|
||||
)
|
||||
return ev
|
||||
|
||||
Reference in New Issue
Block a user