Reformat source code with Black; add pre-commit hooks and CI checks

This commit is contained in:
2020-12-05 10:59:33 -06:00
parent 3823931382
commit d99600f101
49 changed files with 1291 additions and 972 deletions

View File

@@ -19,10 +19,12 @@ class PrimalSolutionComponent(Component):
A component that predicts primal solutions.
"""
def __init__(self,
classifier=AdaptiveClassifier(),
mode="exact",
threshold=MinPrecisionThreshold(0.98)):
def __init__(
self,
classifier=AdaptiveClassifier(),
mode="exact",
threshold=MinPrecisionThreshold(0.98),
):
self.mode = mode
self.classifiers = {}
self.thresholds = {}
@@ -51,9 +53,10 @@ class PrimalSolutionComponent(Component):
features = VariableFeaturesExtractor().extract(training_instances)
solutions = SolutionExtractor().extract(training_instances)
for category in tqdm(features.keys(),
desc="Fit (primal)",
):
for category in tqdm(
features.keys(),
desc="Fit (primal)",
):
x_train = features[category]
for label in [0, 1]:
y_train = solutions[category][:, label].astype(int)
@@ -74,9 +77,15 @@ class PrimalSolutionComponent(Component):
# Find threshold (dynamic or static)
if isinstance(self.threshold_prototype, DynamicThreshold):
self.thresholds[category, label] = self.threshold_prototype.find(clf, x_train, y_train)
self.thresholds[category, label] = self.threshold_prototype.find(
clf,
x_train,
y_train,
)
else:
self.thresholds[category, label] = deepcopy(self.threshold_prototype)
self.thresholds[category, label] = deepcopy(
self.threshold_prototype
)
self.classifiers[category, label] = clf
@@ -98,18 +107,21 @@ class PrimalSolutionComponent(Component):
ws = np.array([[1 - clf, clf] for _ in range(n)])
else:
ws = clf.predict_proba(x_test[category])
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (n, ws.shape)
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (
n,
ws.shape,
)
for (i, (var, index)) in enumerate(var_split[category]):
if ws[i, 1] >= self.thresholds[category, label]:
solution[var][index] = label
return solution
def evaluate(self, instances):
ev = {"Fix zero": {},
"Fix one": {}}
for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)",
):
ev = {"Fix zero": {}, "Fix one": {}}
for instance_idx in tqdm(
range(len(instances)),
desc="Evaluate (primal)",
):
instance = instances[instance_idx]
solution_actual = instance.solution
solution_pred = self.predict(instance)
@@ -143,6 +155,10 @@ class PrimalSolutionComponent(Component):
tn_one = len(pred_one_negative & vars_zero)
fn_one = len(pred_one_negative & vars_one)
ev["Fix zero"][instance_idx] = classifier_evaluation_dict(tp_zero, tn_zero, fp_zero, fn_zero)
ev["Fix one"][instance_idx] = classifier_evaluation_dict(tp_one, tn_one, fp_one, fn_one)
ev["Fix zero"][instance_idx] = classifier_evaluation_dict(
tp_zero, tn_zero, fp_zero, fn_zero
)
ev["Fix one"][instance_idx] = classifier_evaluation_dict(
tp_one, tn_one, fp_one, fn_one
)
return ev