mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-09 19:08:51 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -23,9 +23,9 @@ from miplearn.types import (
|
||||
|
||||
@pytest.fixture
|
||||
def sample() -> TrainingSample:
|
||||
return {
|
||||
"LazyStatic: Enforced": {"c1", "c2", "c4"},
|
||||
}
|
||||
return TrainingSample(
|
||||
lazy_enforced={"c1", "c2", "c4"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -101,7 +101,7 @@ def test_usage_with_solver(features: Features) -> None:
|
||||
)
|
||||
)
|
||||
|
||||
sample: TrainingSample = {}
|
||||
sample: TrainingSample = TrainingSample()
|
||||
stats: LearningSolveStats = {}
|
||||
|
||||
# LearningSolver calls before_solve_mip
|
||||
@@ -152,7 +152,7 @@ def test_usage_with_solver(features: Features) -> None:
|
||||
)
|
||||
|
||||
# Should update training sample
|
||||
assert sample["LazyStatic: Enforced"] == {"c1", "c2", "c3", "c4"}
|
||||
assert sample.lazy_enforced == {"c1", "c2", "c3", "c4"}
|
||||
|
||||
# Should update stats
|
||||
assert stats["LazyStatic: Removed"] == 1
|
||||
|
||||
Reference in New Issue
Block a user