mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Fix all tests
This commit is contained in:
@@ -90,6 +90,13 @@ class AdaptiveClassifier(Classifier):
|
||||
n_samples = x_train.shape[0]
|
||||
assert y_train.shape == (n_samples, 2)
|
||||
|
||||
# If almost all samples belong to the same class, return a fixed prediction and
|
||||
# skip all the other steps.
|
||||
if y_train[:, 0].mean() > 0.999 or y_train[:, 1].mean() > 0.999:
|
||||
self.classifier = CountingClassifier()
|
||||
self.classifier.fit(x_train, y_train)
|
||||
return
|
||||
|
||||
best_name, best_clf, best_score = None, None, -float("inf")
|
||||
for (name, specs) in self.candidates.items():
|
||||
if n_samples < specs.min_samples:
|
||||
|
||||
@@ -85,8 +85,10 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
logger.debug("Training: %s" % (str(v)))
|
||||
label = np.zeros(len(training_instances))
|
||||
label[violation_to_instance_idx[v]] = 1.0
|
||||
label = [[True, False] for i in training_instances]
|
||||
for idx in violation_to_instance_idx[v]:
|
||||
label[idx] = [False, True]
|
||||
label = np.array(label, dtype=np.bool8)
|
||||
classifier.fit(features, label)
|
||||
|
||||
def predict(self, instance):
|
||||
|
||||
@@ -116,9 +116,11 @@ class ConvertTightIneqsIntoEqsStep(Component):
|
||||
if category not in y:
|
||||
y[category] = []
|
||||
if 0 <= slack <= self.slack_tolerance:
|
||||
y[category] += [[1]]
|
||||
y[category] += [[False, True]]
|
||||
else:
|
||||
y[category] += [[0]]
|
||||
y[category] += [[True, False]]
|
||||
for category in y.keys():
|
||||
y[category] = np.array(y[category], dtype=np.bool8)
|
||||
return y
|
||||
|
||||
def predict(self, x):
|
||||
|
||||
Reference in New Issue
Block a user