Fix all tests

master
Alinson S. Xavier 5 years ago
parent 3ab3bb3c1f
commit b0b013dd0a

@ -90,6 +90,13 @@ class AdaptiveClassifier(Classifier):
n_samples = x_train.shape[0] n_samples = x_train.shape[0]
assert y_train.shape == (n_samples, 2) assert y_train.shape == (n_samples, 2)
# If almost all samples belong to the same class, return a fixed prediction and
# skip all the other steps.
if y_train[:, 0].mean() > 0.999 or y_train[:, 1].mean() > 0.999:
self.classifier = CountingClassifier()
self.classifier.fit(x_train, y_train)
return
best_name, best_clf, best_score = None, None, -float("inf") best_name, best_clf, best_score = None, None, -float("inf")
for (name, specs) in self.candidates.items(): for (name, specs) in self.candidates.items():
if n_samples < specs.min_samples: if n_samples < specs.min_samples:

@ -85,8 +85,10 @@ class DynamicLazyConstraintsComponent(Component):
disable=not sys.stdout.isatty(), disable=not sys.stdout.isatty(),
): ):
logger.debug("Training: %s" % (str(v))) logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances)) label = [[True, False] for i in training_instances]
label[violation_to_instance_idx[v]] = 1.0 for idx in violation_to_instance_idx[v]:
label[idx] = [False, True]
label = np.array(label, dtype=np.bool8)
classifier.fit(features, label) classifier.fit(features, label)
def predict(self, instance): def predict(self, instance):

@ -116,9 +116,11 @@ class ConvertTightIneqsIntoEqsStep(Component):
if category not in y: if category not in y:
y[category] = [] y[category] = []
if 0 <= slack <= self.slack_tolerance: if 0 <= slack <= self.slack_tolerance:
y[category] += [[1]] y[category] += [[False, True]]
else: else:
y[category] += [[0]] y[category] += [[True, False]]
for category in y.keys():
y[category] = np.array(y[category], dtype=np.bool8)
return y return y
def predict(self, x): def predict(self, x):

@ -6,6 +6,7 @@ from unittest.mock import Mock
import numpy as np import numpy as np
from numpy.linalg import norm from numpy.linalg import norm
from numpy.testing import assert_array_equal
from miplearn.classifiers import Classifier from miplearn.classifiers import Classifier
from miplearn.components.lazy_dynamic import DynamicLazyConstraintsComponent from miplearn.components.lazy_dynamic import DynamicLazyConstraintsComponent
@ -42,15 +43,36 @@ def test_lazy_fit():
assert norm(expected_x_train_c - actual_x_train_c) < E assert norm(expected_x_train_c - actual_x_train_c) < E
# Should provide correct y_train to each classifier # Should provide correct y_train to each classifier
expected_y_train_a = np.array([1.0, 0.0]) expected_y_train_a = np.array(
expected_y_train_b = np.array([1.0, 1.0]) [
expected_y_train_c = np.array([0.0, 1.0]) [False, True],
actual_y_train_a = component.classifiers["a"].fit.call_args[0][1] [True, False],
actual_y_train_b = component.classifiers["b"].fit.call_args[0][1] ]
actual_y_train_c = component.classifiers["c"].fit.call_args[0][1] )
assert norm(expected_y_train_a - actual_y_train_a) < E expected_y_train_b = np.array(
assert norm(expected_y_train_b - actual_y_train_b) < E [
assert norm(expected_y_train_c - actual_y_train_c) < E [False, True],
[False, True],
]
)
expected_y_train_c = np.array(
[
[True, False],
[False, True],
]
)
assert_array_equal(
component.classifiers["a"].fit.call_args[0][1],
expected_y_train_a,
)
assert_array_equal(
component.classifiers["b"].fit.call_args[0][1],
expected_y_train_b,
)
assert_array_equal(
component.classifiers["c"].fit.call_args[0][1],
expected_y_train_c,
)
def test_lazy_before(): def test_lazy_before():

@ -3,6 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
import dill
import pickle import pickle
import tempfile import tempfile
import os import os
@ -44,7 +45,7 @@ def test_learning_solver():
# Assert solver is picklable # Assert solver is picklable
with tempfile.TemporaryFile() as file: with tempfile.TemporaryFile() as file:
pickle.dump(solver, file) dill.dump(solver, file)
def test_solve_without_lp(): def test_solve_without_lp():

Loading…
Cancel
Save