Redesign component.evaluate

This commit is contained in:
2021-04-02 08:09:35 -05:00
parent 0c687692f7
commit 0bce2051a8
9 changed files with 221 additions and 178 deletions

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Dict
from unittest.mock import Mock
import numpy as np
@@ -10,6 +10,7 @@ from scipy.stats import randint
from miplearn import Classifier, LearningSolver
from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.types import TrainingSample, Features
@@ -69,7 +70,7 @@ def test_xy() -> None:
[True, False],
]
}
xy = PrimalSolutionComponent.xy(features, sample)
xy = PrimalSolutionComponent.sample_xy(features, sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected
@@ -122,7 +123,7 @@ def test_xy_without_lp_solution() -> None:
[True, False],
]
}
xy = PrimalSolutionComponent.xy(features, sample)
xy = PrimalSolutionComponent.sample_xy(features, sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected
@@ -169,11 +170,11 @@ def test_predict() -> None:
}
}
}
x, _ = PrimalSolutionComponent.xy(features, sample)
x, _ = PrimalSolutionComponent.sample_xy(features, sample)
comp = PrimalSolutionComponent()
comp.classifiers = {"default": clf}
comp.thresholds = {"default": thr}
solution_actual = comp.predict(features, sample)
solution_actual = comp.sample_predict(features, sample)
clf.predict_proba.assert_called_once()
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
thr.predict.assert_called_once()
@@ -229,3 +230,43 @@ def test_usage():
assert stats["Primal: Free"] == 0
assert stats["Primal: One"] + stats["Primal: Zero"] == 10
assert stats["Lower bound"] == stats["Warm start value"]
def test_evaluate() -> None:
comp = PrimalSolutionComponent()
comp.sample_predict = lambda _, __: { # type: ignore
"x": {
0: 1.0,
1: 0.0,
2: 0.0,
3: None,
4: 1.0,
}
}
features: Features = {
"Variables": {
"x": {
0: {},
1: {},
2: {},
3: {},
4: {},
}
}
}
sample: TrainingSample = {
"Solution": {
"x": {
0: 1.0,
1: 1.0,
2: 0.0,
3: 1.0,
4: 1.0,
}
}
}
ev = comp.sample_evaluate(features, sample)
assert ev == {
0: classifier_evaluation_dict(tp=1, fp=1, tn=3, fn=0),
1: classifier_evaluation_dict(tp=2, fp=0, tn=1, fn=2),
}