mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Request constraint features/categories in bulk
This commit is contained in:
@@ -6,7 +6,6 @@ from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||
@@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
|
||||
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
|
||||
return_value=[5.0]
|
||||
)
|
||||
instances[0].get_constraint_category = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
instances[0].get_constraint_categories = Mock( # type: ignore
|
||||
return_value={
|
||||
"c1": "type-a",
|
||||
"c2": "type-a",
|
||||
"c3": "type-b",
|
||||
"c4": "type-b",
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[0].get_constraint_features = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
return_value={
|
||||
"c1": [1.0, 2.0, 3.0],
|
||||
"c2": [4.0, 5.0, 6.0],
|
||||
"c3": [1.0, 2.0],
|
||||
"c4": [3.0, 4.0],
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[1].samples = [
|
||||
Sample(
|
||||
@@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
|
||||
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
||||
return_value=[8.0]
|
||||
)
|
||||
instances[1].get_constraint_category = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
instances[1].get_constraint_categories = Mock( # type: ignore
|
||||
return_value={
|
||||
"c1": None,
|
||||
"c2": "type-a",
|
||||
"c3": "type-b",
|
||||
"c4": "type-b",
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[1].get_constraint_features = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
return_value={
|
||||
"c2": [7.0, 8.0, 9.0],
|
||||
"c3": [5.0, 6.0],
|
||||
"c4": [7.0, 8.0],
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
return instances
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ def test_knapsack() -> None:
|
||||
sa_rhs_up=[2.0],
|
||||
senses=["="],
|
||||
slacks=[0.0],
|
||||
user_features=[[0.0]],
|
||||
user_features=[None],
|
||||
),
|
||||
)
|
||||
assert_equals(
|
||||
|
||||
Reference in New Issue
Block a user