Rewrite PrimalSolutionComponent.sample_xy

master
Alinson S. Xavier 5 years ago
parent d90d7762e3
commit 2979bd157c
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -20,7 +20,7 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
from miplearn.components import classifier_evaluation_dict from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features from miplearn.features import TrainingSample, Features, Sample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.types import ( from miplearn.types import (
LearningSolveStats, LearningSolveStats,
@ -179,6 +179,50 @@ class PrimalSolutionComponent(Component):
y[category] += [[opt_value < 0.5, opt_value >= 0.5]] y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y return x, y
@overrides
def sample_xy(
self,
sample: Sample,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}
assert sample.after_load is not None
assert sample.after_load.variables is not None
for (var_name, var) in sample.after_load.variables.items():
# Initialize categories
category = var.category
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
# Features
sf = sample.after_load
if sample.after_lp is not None:
sf = sample.after_lp
assert sf.instance is not None
features = list(sf.instance.to_list())
assert sf.variables is not None
assert sf.variables[var_name] is not None
features.extend(sf.variables[var_name].to_list())
x[category].append(features)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.variables is not None
assert sample.after_mip.variables[var_name] is not None
opt_value = sample.after_mip.variables[var_name].value
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
y[category].append([opt_value < 0.5, opt_value >= 0.5])
return x, y
@overrides @overrides
def sample_evaluate_old( def sample_evaluate_old(
self, self,

@ -5,6 +5,7 @@ from typing import cast
from unittest.mock import Mock from unittest.mock import Mock
import numpy as np import numpy as np
import pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.stats import randint from scipy.stats import randint
@ -12,13 +13,83 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features import TrainingSample, Variable, Features from miplearn.features import (
TrainingSample,
Variable,
Features,
Sample,
InstanceFeatures,
)
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.problems.tsp import TravelingSalesmanGenerator from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.solvers.learning import LearningSolver from miplearn.solvers.learning import LearningSolver
def test_xy() -> None: @pytest.fixture
def sample() -> Sample:
sample = Sample(
after_load=Features(
variables={
"x[0]": Variable(category="default"),
"x[1]": Variable(category=None),
"x[2]": Variable(category="default"),
"x[3]": Variable(category="default"),
},
),
after_lp=Features(
instance=InstanceFeatures(),
variables={
"x[0]": Variable(),
"x[1]": Variable(),
"x[2]": Variable(),
"x[3]": Variable(),
},
),
after_mip=Features(
variables={
"x[0]": Variable(value=0.0),
"x[1]": Variable(value=0.0),
"x[2]": Variable(value=1.0),
"x[3]": Variable(value=0.0),
}
),
)
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_lp.variables["x[0]"].to_list = Mock( # type: ignore
return_value=[0.0, 0.0]
)
sample.after_lp.variables["x[2]"].to_list = Mock( # type: ignore
return_value=[1.0, 0.0]
)
sample.after_lp.variables["x[3]"].to_list = Mock( # type: ignore
return_value=[1.0, 1.0]
)
return sample
def test_xy(sample: Sample) -> None:
x_expected = {
"default": [
[5.0, 0.0, 0.0],
[5.0, 1.0, 0.0],
[5.0, 1.0, 1.0],
]
}
y_expected = {
"default": [
[True, False],
[False, True],
[True, False],
]
}
xy = PrimalSolutionComponent().sample_xy(sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected
assert y_actual == y_expected
def test_xy_old() -> None:
features = Features( features = Features(
variables={ variables={
"x[0]": Variable( "x[0]": Variable(

Loading…
Cancel
Save