Rewrite PrimalSolutionComponent.sample_xy

This commit is contained in:
2021-04-11 21:52:59 -05:00
parent d90d7762e3
commit 2979bd157c
2 changed files with 118 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ from typing import cast
from unittest.mock import Mock
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from scipy.stats import randint
@@ -12,13 +13,83 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features import TrainingSample, Variable, Features
from miplearn.features import (
TrainingSample,
Variable,
Features,
Sample,
InstanceFeatures,
)
from miplearn.instance.base import Instance
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.solvers.learning import LearningSolver
def test_xy() -> None:
@pytest.fixture
def sample() -> Sample:
sample = Sample(
after_load=Features(
variables={
"x[0]": Variable(category="default"),
"x[1]": Variable(category=None),
"x[2]": Variable(category="default"),
"x[3]": Variable(category="default"),
},
),
after_lp=Features(
instance=InstanceFeatures(),
variables={
"x[0]": Variable(),
"x[1]": Variable(),
"x[2]": Variable(),
"x[3]": Variable(),
},
),
after_mip=Features(
variables={
"x[0]": Variable(value=0.0),
"x[1]": Variable(value=0.0),
"x[2]": Variable(value=1.0),
"x[3]": Variable(value=0.0),
}
),
)
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_lp.variables["x[0]"].to_list = Mock( # type: ignore
return_value=[0.0, 0.0]
)
sample.after_lp.variables["x[2]"].to_list = Mock( # type: ignore
return_value=[1.0, 0.0]
)
sample.after_lp.variables["x[3]"].to_list = Mock( # type: ignore
return_value=[1.0, 1.0]
)
return sample
def test_xy(sample: Sample) -> None:
x_expected = {
"default": [
[5.0, 0.0, 0.0],
[5.0, 1.0, 0.0],
[5.0, 1.0, 1.0],
]
}
y_expected = {
"default": [
[True, False],
[False, True],
[True, False],
]
}
xy = PrimalSolutionComponent().sample_xy(sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected
assert y_actual == y_expected
def test_xy_old() -> None:
features = Features(
variables={
"x[0]": Variable(