Rewrite ObjectiveValueComponent.sample_xy

This commit is contained in:
2021-04-11 21:27:25 -05:00
parent 2da60dd293
commit d90d7762e3
5 changed files with 108 additions and 15 deletions

View File

@@ -10,8 +10,9 @@ from numpy.testing import assert_array_equal
from miplearn.classifiers import Regressor
from miplearn.components.objective import ObjectiveValueComponent
from miplearn.features import TrainingSample, InstanceFeatures, Features
from miplearn.features import TrainingSample, InstanceFeatures, Features, Sample
from miplearn.instance.base import Instance
from miplearn.solvers.internal import MIPSolveStats, LPSolveStats
from miplearn.solvers.learning import LearningSolver
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
@@ -41,6 +42,27 @@ def sample_old() -> TrainingSample:
)
@pytest.fixture
def sample() -> Sample:
sample = Sample(
after_load=Features(
instance=InstanceFeatures(),
),
after_lp=Features(
lp_solve=LPSolveStats(),
),
after_mip=Features(
mip_solve=MIPSolveStats(
mip_lower_bound=1.0,
mip_upper_bound=2.0,
)
),
)
sample.after_load.instance.to_list = Mock(return_value=[1.0, 2.0]) # type: ignore
sample.after_lp.lp_solve.to_list = Mock(return_value=[3.0]) # type: ignore
return sample
@pytest.fixture
def sample_without_lp() -> TrainingSample:
return TrainingSample(
@@ -57,10 +79,7 @@ def sample_without_ub_old() -> TrainingSample:
)
def test_sample_xy(
instance: Instance,
sample_old: TrainingSample,
) -> None:
def test_sample_xy(sample: Sample) -> None:
x_expected = {
"Lower bound": [[1.0, 2.0, 3.0]],
"Upper bound": [[1.0, 2.0, 3.0]],
@@ -69,7 +88,7 @@ def test_sample_xy(
"Lower bound": [[1.0]],
"Upper bound": [[2.0]],
}
xy = ObjectiveValueComponent().sample_xy_old(instance, sample_old)
xy = ObjectiveValueComponent().sample_xy(sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected