mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Fix all tests
This commit is contained in:
@@ -6,6 +6,7 @@ from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.components.lazy_dynamic import DynamicLazyConstraintsComponent
|
||||
@@ -42,15 +43,36 @@ def test_lazy_fit():
|
||||
assert norm(expected_x_train_c - actual_x_train_c) < E
|
||||
|
||||
# Should provide correct y_train to each classifier
|
||||
expected_y_train_a = np.array([1.0, 0.0])
|
||||
expected_y_train_b = np.array([1.0, 1.0])
|
||||
expected_y_train_c = np.array([0.0, 1.0])
|
||||
actual_y_train_a = component.classifiers["a"].fit.call_args[0][1]
|
||||
actual_y_train_b = component.classifiers["b"].fit.call_args[0][1]
|
||||
actual_y_train_c = component.classifiers["c"].fit.call_args[0][1]
|
||||
assert norm(expected_y_train_a - actual_y_train_a) < E
|
||||
assert norm(expected_y_train_b - actual_y_train_b) < E
|
||||
assert norm(expected_y_train_c - actual_y_train_c) < E
|
||||
expected_y_train_a = np.array(
|
||||
[
|
||||
[False, True],
|
||||
[True, False],
|
||||
]
|
||||
)
|
||||
expected_y_train_b = np.array(
|
||||
[
|
||||
[False, True],
|
||||
[False, True],
|
||||
]
|
||||
)
|
||||
expected_y_train_c = np.array(
|
||||
[
|
||||
[True, False],
|
||||
[False, True],
|
||||
]
|
||||
)
|
||||
assert_array_equal(
|
||||
component.classifiers["a"].fit.call_args[0][1],
|
||||
expected_y_train_a,
|
||||
)
|
||||
assert_array_equal(
|
||||
component.classifiers["b"].fit.call_args[0][1],
|
||||
expected_y_train_b,
|
||||
)
|
||||
assert_array_equal(
|
||||
component.classifiers["c"].fit.call_args[0][1],
|
||||
expected_y_train_c,
|
||||
)
|
||||
|
||||
|
||||
def test_lazy_before():
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import dill
|
||||
import pickle
|
||||
import tempfile
|
||||
import os
|
||||
@@ -44,7 +45,7 @@ def test_learning_solver():
|
||||
|
||||
# Assert solver is picklable
|
||||
with tempfile.TemporaryFile() as file:
|
||||
pickle.dump(solver, file)
|
||||
dill.dump(solver, file)
|
||||
|
||||
|
||||
def test_solve_without_lp():
|
||||
|
||||
Reference in New Issue
Block a user