diff --git a/miplearn/components/steps/drop_redundant.py b/miplearn/components/steps/drop_redundant.py index ba21ee6..1309c20 100644 --- a/miplearn/components/steps/drop_redundant.py +++ b/miplearn/components/steps/drop_redundant.py @@ -60,6 +60,7 @@ class DropRedundantInequalitiesStep(Component): ) y = self.predict(x) + self.pool = [] self.total_dropped = 0 self.total_restored = 0 self.total_kept = 0 @@ -102,7 +103,7 @@ class DropRedundantInequalitiesStep(Component): x = self.x(training_instances) y = self.y(training_instances) logger.debug("Fitting...") - for category in tqdm(x.keys(), desc="Fit (rlx:drop_ineq)"): + for category in tqdm(x.keys(), desc="Fit (drop)"): if category not in self.classifiers: self.classifiers[category] = deepcopy(self.classifier_prototype) self.classifiers[category].fit(x[category], np.array(y[category])) @@ -130,7 +131,7 @@ class DropRedundantInequalitiesStep(Component): x = {} for instance in tqdm( InstanceIterator(instances), - desc="Extract (rlx:drop_ineq:x)", + desc="Extract (drop:x)", disable=len(instances) < 5, ): for training_data in instance.training_data: @@ -153,7 +154,7 @@ class DropRedundantInequalitiesStep(Component): y = {} for instance in tqdm( InstanceIterator(instances), - desc="Extract (rlx:drop_ineq:y)", + desc="Extract (drop:y)", disable=len(instances) < 5, ): for training_data in instance.training_data: @@ -189,9 +190,13 @@ class DropRedundantInequalitiesStep(Component): y_true = self.y([instance]) y_pred = self.predict(x) tp, tn, fp, fn = 0, 0, 0, 0 - for category in y_true.keys(): + for category in tqdm( + y_true.keys(), + disable=len(y_true) < 100, + desc="Eval (drop)", + ): for i in range(len(y_true[category])): - if y_pred[category][i][1] == 1: + if (category in y_pred) and (y_pred[category][i][1] == 1): if y_true[category][i][1] == 1: tp += 1 else: