DropRedundant: Update for new classifier interface

This commit is contained in:
2021-02-02 09:26:16 -06:00
parent d3c5371fa5
commit 8153dfc825
2 changed files with 66 additions and 14 deletions

View File

@@ -45,6 +45,10 @@ class DropRedundantInequalitiesStep(Component):
self.violation_tolerance = violation_tolerance
self.max_iterations = max_iterations
self.current_iteration = 0
self.total_dropped = 0
self.total_restored = 0
self.total_kept = 0
self.total_iterations = 0
def before_solve(self, solver, instance, _):
self.current_iteration = 0
@@ -62,7 +66,7 @@ class DropRedundantInequalitiesStep(Component):
self.total_iterations = 0
for category in y.keys():
for i in range(len(y[category])):
if y[category][i][0] == 1:
if y[category][i][1] == 1:
cid = constraints[category][i]
c = LazyConstraint(
cid=cid,
@@ -101,7 +105,7 @@ class DropRedundantInequalitiesStep(Component):
for category in tqdm(x.keys(), desc="Fit (rlx:drop_ineq)"):
if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category].fit(x[category], y[category])
self.classifiers[category].fit(x[category], np.array(y[category]))
@staticmethod
def _x_test(instance, constraint_ids):
@@ -160,9 +164,9 @@ class DropRedundantInequalitiesStep(Component):
if category not in y:
y[category] = []
if slack > self.slack_tolerance:
y[category] += [[1]]
y[category] += [[False, True]]
else:
y[category] += [[0]]
y[category] += [[True, False]]
return y
def predict(self, x):
@@ -175,9 +179,9 @@ class DropRedundantInequalitiesStep(Component):
proba = self.classifiers[category].predict_proba(x_cat)
for i in range(len(proba)):
if proba[i][1] >= self.threshold:
y[category] += [[1]]
y[category] += [[False, True]]
else:
y[category] += [[0]]
y[category] += [[True, False]]
return y
def evaluate(self, instance):
@@ -187,13 +191,13 @@ class DropRedundantInequalitiesStep(Component):
tp, tn, fp, fn = 0, 0, 0, 0
for category in y_true.keys():
for i in range(len(y_true[category])):
if y_pred[category][i][0] == 1:
if y_true[category][i][0] == 1:
if y_pred[category][i][1] == 1:
if y_true[category][i][1] == 1:
tp += 1
else:
fp += 1
else:
if y_true[category][i][0] == 1:
if y_true[category][i][1] == 1:
fn += 1
else:
tn += 1