mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 18:38:51 -06:00
Allow user to attach arbitrary data to violations
This commit is contained in:
@@ -10,6 +10,7 @@ import pytest
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.dynamic_common import DynamicConstraintsComponent
|
||||
from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent
|
||||
from miplearn.features.sample import MemorySample
|
||||
from miplearn.instance.base import Instance
|
||||
@@ -24,13 +25,23 @@ def training_instances() -> List[Instance]:
|
||||
samples_0 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c1", "c2"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c1": 0,
|
||||
b"c2": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c2", "c3"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c2": 0,
|
||||
b"c3": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
@@ -55,7 +66,12 @@ def training_instances() -> List[Instance]:
|
||||
samples_1 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c3", "c4"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c3": 0,
|
||||
b"c4": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([8.0]),
|
||||
},
|
||||
)
|
||||
@@ -83,8 +99,8 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.pre_fit(
|
||||
[
|
||||
np.array(["c1", "c3", "c4"], dtype="S"),
|
||||
np.array(["c1", "c2", "c4"], dtype="S"),
|
||||
{b"c1": 0, b"c3": 0, b"c4": 0},
|
||||
{b"c1": 0, b"c2": 0, b"c4": 0},
|
||||
]
|
||||
)
|
||||
x_expected = {
|
||||
@@ -105,7 +121,10 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
|
||||
def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.known_cids.extend([b"c1", b"c2", b"c3", b"c4"])
|
||||
comp.known_violations[b"c1"] = 0
|
||||
comp.known_violations[b"c2"] = 0
|
||||
comp.known_violations[b"c3"] = 0
|
||||
comp.known_violations[b"c4"] = 0
|
||||
comp.thresholds[b"type-a"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.thresholds[b"type-b"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.classifiers[b"type-a"] = Mock(spec=Classifier)
|
||||
|
||||
Reference in New Issue
Block a user