mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -28,7 +28,7 @@ def test_convert_tight_usage():
|
||||
original_upper_bound = stats["Upper bound"]
|
||||
|
||||
# Should collect training data
|
||||
assert instance.training_data[0]["slacks"]["eq_capacity"] == 0.0
|
||||
assert instance.training_data[0].slacks["eq_capacity"] == 0.0
|
||||
|
||||
# Fit and resolve
|
||||
solver.fit([instance])
|
||||
|
||||
@@ -12,6 +12,7 @@ from miplearn.components.steps.drop_redundant import DropRedundantInequalitiesSt
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from tests.fixtures.infeasible import get_infeasible_instance
|
||||
from tests.fixtures.redundant import get_instance_with_redundancy
|
||||
|
||||
@@ -85,8 +86,8 @@ def test_drop_redundant():
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
training_data=None,
|
||||
features=Features(),
|
||||
training_data=TrainingSample(),
|
||||
)
|
||||
|
||||
# Should query list of constraints
|
||||
@@ -129,13 +130,13 @@ def test_drop_redundant():
|
||||
)
|
||||
|
||||
# LearningSolver calls after_solve
|
||||
training_data = {}
|
||||
training_data = TrainingSample()
|
||||
component.after_solve_mip(
|
||||
solver=solver,
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
features=Features(),
|
||||
training_data=training_data,
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ def test_drop_redundant():
|
||||
internal.get_inequality_slacks.assert_called_once()
|
||||
|
||||
# Should store constraint slacks in instance object
|
||||
assert training_data["slacks"] == {
|
||||
assert training_data.slacks == {
|
||||
"c1": 0.5,
|
||||
"c2": 0.0,
|
||||
"c3": 0.0,
|
||||
@@ -166,8 +167,8 @@ def test_drop_redundant_with_check_feasibility():
|
||||
instance=instance,
|
||||
model=None,
|
||||
stats={},
|
||||
features=None,
|
||||
training_data=None,
|
||||
features=Features(),
|
||||
training_data=TrainingSample(),
|
||||
)
|
||||
|
||||
# Assert constraints are extracted
|
||||
@@ -224,14 +225,14 @@ def test_x_y_fit_predict_evaluate():
|
||||
|
||||
# First mock instance
|
||||
instances[0].training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.05,
|
||||
"c3": 0.00,
|
||||
"c4": 30.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
instances[0].get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
@@ -251,14 +252,14 @@ def test_x_y_fit_predict_evaluate():
|
||||
|
||||
# Second mock instance
|
||||
instances[1].training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c3": 0.30,
|
||||
"c4": 0.00,
|
||||
"c5": 0.00,
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
instances[1].get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
@@ -343,22 +344,22 @@ def test_x_y_fit_predict_evaluate():
|
||||
def test_x_multiple_solves():
|
||||
instance = Mock(spec=Instance)
|
||||
instance.training_data = [
|
||||
{
|
||||
"slacks": {
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.05,
|
||||
"c3": 0.00,
|
||||
"c4": 30.0,
|
||||
}
|
||||
},
|
||||
{
|
||||
"slacks": {
|
||||
),
|
||||
TrainingSample(
|
||||
slacks={
|
||||
"c1": 0.00,
|
||||
"c2": 0.00,
|
||||
"c3": 1.00,
|
||||
"c4": 0.0,
|
||||
}
|
||||
},
|
||||
),
|
||||
]
|
||||
instance.get_constraint_category = Mock(
|
||||
side_effect=lambda cid: {
|
||||
|
||||
Reference in New Issue
Block a user