mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -28,7 +28,7 @@ def test_convert_tight_usage():
|
||||
original_upper_bound = stats["Upper bound"]
|
||||
|
||||
# Should collect training data
|
||||
assert instance.training_data[0]["slacks"]["eq_capacity"] == 0.0
|
||||
assert instance.training_data[0].slacks["eq_capacity"] == 0.0
|
||||
|
||||
# Fit and resolve
|
||||
solver.fit([instance])
|
||||
|
||||
@@ -12,6 +12,7 @@ from miplearn.components.steps.drop_redundant import DropRedundantInequalitiesSt
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from tests.fixtures.infeasible import get_infeasible_instance
|
||||
from tests.fixtures.redundant import get_instance_with_redundancy
|
||||
|
||||
@@ -85,8 +86,8 @@ def test_drop_redundant():
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
training_data=None,
|
||||
features=Features(),
|
||||
training_data=TrainingSample(),
|
||||
)
|
||||
|
||||
# Should query list of constraints
|
||||
@@ -129,13 +130,13 @@ def test_drop_redundant():
|
||||
)
|
||||
|
||||
# LearningSolver calls after_solve
|
||||
training_data = {}
|
||||
training_data = TrainingSample()
|
||||
component.after_solve_mip(
|
||||
solver=solver,
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
features=Features(),
|
||||
training_data=training_data,
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ def test_drop_redundant():
|
||||
internal.get_inequality_slacks.assert_called_once()
|
||||
|
||||
# Should store constraint slacks in instance object
|
||||
assert training_data["slacks"] == {
|
||||
assert training_data.slacks == {
|
||||
"c1": 0.5,
|
||||
"c2": 0.0,
|
||||
"c3": 0.0,
|
||||
@@ -166,8 +167,8 @@ def test_drop_redundant_with_check_feasibility():
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
training_data=None,
|
||||
features=Features(),
|
||||
training_data=TrainingSample(),
|
||||
)
|
||||
|
||||
# Assert constraints are extracted
|
||||
@@ -224,14 +225,14 @@ def test_x_y_fit_predict_evaluate():
|
||||
|
||||
# First mock instance
|
||||
instances[0].training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.05,
|
||||
"c3": 0.00,
|
||||
"c4": 30.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
instances[0].get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
@@ -251,14 +252,14 @@ def test_x_y_fit_predict_evaluate():
|
||||
|
||||
# Second mock instance
|
||||
instances[1].training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c3": 0.30,
|
||||
"c4": 0.00,
|
||||
"c5": 0.00,
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
instances[1].get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
@@ -343,22 +344,22 @@ def test_x_y_fit_predict_evaluate():
|
||||
def test_x_multiple_solves():
|
||||
instance = Mock(spec=Instance)
|
||||
instance.training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.05,
|
||||
"c3": 0.00,
|
||||
"c4": 30.0,
|
||||
}
|
||||
},
|
||||
{
|
||||
"slacks": {
|
||||
),
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.00,
|
||||
"c3": 1.00,
|
||||
"c4": 0.0,
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
instance.get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
|
||||
@@ -23,9 +23,9 @@ from miplearn.types import (
|
||||
|
||||
@pytest.fixture
|
||||
def sample() -> TrainingSample:
|
||||
return {
|
||||
"LazyStatic: Enforced": {"c1", "c2", "c4"},
|
||||
}
|
||||
return TrainingSample(
|
||||
lazy_enforced={"c1", "c2", "c4"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -101,7 +101,7 @@ def test_usage_with_solver(features: Features) -> None:
|
||||
)
|
||||
)
|
||||
|
||||
sample: TrainingSample = {}
|
||||
sample: TrainingSample = TrainingSample()
|
||||
stats: LearningSolveStats = {}
|
||||
|
||||
# LearningSolver calls before_solve_mip
|
||||
@@ -152,7 +152,7 @@ def test_usage_with_solver(features: Features) -> None:
|
||||
)
|
||||
|
||||
# Should update training sample
|
||||
assert sample["LazyStatic: Enforced"] == {"c1", "c2", "c3", "c4"}
|
||||
assert sample.lazy_enforced == {"c1", "c2", "c3", "c4"}
|
||||
|
||||
# Should update stats
|
||||
assert stats["LazyStatic: Removed"] == 1
|
||||
|
||||
@@ -26,27 +26,27 @@ def features() -> Features:
|
||||
|
||||
@pytest.fixture
|
||||
def sample() -> TrainingSample:
|
||||
return {
|
||||
"Lower bound": 1.0,
|
||||
"Upper bound": 2.0,
|
||||
"LP value": 3.0,
|
||||
}
|
||||
return TrainingSample(
|
||||
lower_bound=1.0,
|
||||
upper_bound=2.0,
|
||||
lp_value=3.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_without_lp() -> TrainingSample:
|
||||
return {
|
||||
"Lower bound": 1.0,
|
||||
"Upper bound": 2.0,
|
||||
}
|
||||
return TrainingSample(
|
||||
lower_bound=1.0,
|
||||
upper_bound=2.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_without_ub() -> TrainingSample:
|
||||
return {
|
||||
"Lower bound": 1.0,
|
||||
"LP value": 3.0,
|
||||
}
|
||||
return TrainingSample(
|
||||
lower_bound=1.0,
|
||||
lp_value=3.0,
|
||||
)
|
||||
|
||||
|
||||
def test_sample_xy(
|
||||
|
||||
@@ -38,8 +38,8 @@ def test_xy() -> None:
|
||||
}
|
||||
}
|
||||
)
|
||||
sample: TrainingSample = {
|
||||
"Solution": {
|
||||
sample = TrainingSample(
|
||||
solution={
|
||||
"x": {
|
||||
0: 0.0,
|
||||
1: 1.0,
|
||||
@@ -47,7 +47,7 @@ def test_xy() -> None:
|
||||
3: 0.0,
|
||||
}
|
||||
},
|
||||
"LP solution": {
|
||||
lp_solution={
|
||||
"x": {
|
||||
0: 0.1,
|
||||
1: 0.1,
|
||||
@@ -55,7 +55,7 @@ def test_xy() -> None:
|
||||
3: 0.1,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
x_expected = {
|
||||
"default": [
|
||||
[0.0, 0.0, 0.1],
|
||||
@@ -99,8 +99,8 @@ def test_xy_without_lp_solution() -> None:
|
||||
}
|
||||
}
|
||||
)
|
||||
sample: TrainingSample = {
|
||||
"Solution": {
|
||||
sample = TrainingSample(
|
||||
solution={
|
||||
"x": {
|
||||
0: 0.0,
|
||||
1: 1.0,
|
||||
@@ -108,7 +108,7 @@ def test_xy_without_lp_solution() -> None:
|
||||
3: 0.0,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
x_expected = {
|
||||
"default": [
|
||||
[0.0, 0.0],
|
||||
@@ -161,15 +161,15 @@ def test_predict() -> None:
|
||||
}
|
||||
}
|
||||
)
|
||||
sample: TrainingSample = {
|
||||
"LP solution": {
|
||||
sample = TrainingSample(
|
||||
lp_solution={
|
||||
"x": {
|
||||
0: 0.1,
|
||||
1: 0.5,
|
||||
2: 0.9,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
x, _ = PrimalSolutionComponent.sample_xy(features, sample)
|
||||
comp = PrimalSolutionComponent()
|
||||
comp.classifiers = {"default": clf}
|
||||
@@ -254,8 +254,8 @@ def test_evaluate() -> None:
|
||||
}
|
||||
}
|
||||
)
|
||||
sample: TrainingSample = {
|
||||
"Solution": {
|
||||
sample = TrainingSample(
|
||||
solution={
|
||||
"x": {
|
||||
0: 1.0,
|
||||
1: 1.0,
|
||||
@@ -264,7 +264,7 @@ def test_evaluate() -> None:
|
||||
4: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
ev = comp.sample_evaluate(features, sample)
|
||||
assert ev == {
|
||||
0: classifier_evaluation_dict(tp=1, fp=1, tn=3, fn=0),
|
||||
|
||||
Reference in New Issue
Block a user