Rewrite PrimalSolutionComponent.sample_xy

master
Alinson S. Xavier 5 years ago
parent d90d7762e3
commit 2979bd157c
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -20,7 +20,7 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Sample
from miplearn.instance.base import Instance
from miplearn.types import (
LearningSolveStats,
@ -179,6 +179,50 @@ class PrimalSolutionComponent(Component):
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y
@overrides
def sample_xy(
self,
sample: Sample,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}
assert sample.after_load is not None
assert sample.after_load.variables is not None
for (var_name, var) in sample.after_load.variables.items():
# Initialize categories
category = var.category
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
# Features
sf = sample.after_load
if sample.after_lp is not None:
sf = sample.after_lp
assert sf.instance is not None
features = list(sf.instance.to_list())
assert sf.variables is not None
assert sf.variables[var_name] is not None
features.extend(sf.variables[var_name].to_list())
x[category].append(features)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.variables is not None
assert sample.after_mip.variables[var_name] is not None
opt_value = sample.after_mip.variables[var_name].value
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
y[category].append([opt_value < 0.5, opt_value >= 0.5])
return x, y
@overrides
def sample_evaluate_old(
self,

@ -5,6 +5,7 @@ from typing import cast
from unittest.mock import Mock
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from scipy.stats import randint
@ -12,13 +13,83 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features import TrainingSample, Variable, Features
from miplearn.features import (
TrainingSample,
Variable,
Features,
Sample,
InstanceFeatures,
)
from miplearn.instance.base import Instance
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.solvers.learning import LearningSolver
def test_xy() -> None:
@pytest.fixture
def sample() -> Sample:
sample = Sample(
after_load=Features(
variables={
"x[0]": Variable(category="default"),
"x[1]": Variable(category=None),
"x[2]": Variable(category="default"),
"x[3]": Variable(category="default"),
},
),
after_lp=Features(
instance=InstanceFeatures(),
variables={
"x[0]": Variable(),
"x[1]": Variable(),
"x[2]": Variable(),
"x[3]": Variable(),
},
),
after_mip=Features(
variables={
"x[0]": Variable(value=0.0),
"x[1]": Variable(value=0.0),
"x[2]": Variable(value=1.0),
"x[3]": Variable(value=0.0),
}
),
)
sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore
sample.after_lp.variables["x[0]"].to_list = Mock( # type: ignore
return_value=[0.0, 0.0]
)
sample.after_lp.variables["x[2]"].to_list = Mock( # type: ignore
return_value=[1.0, 0.0]
)
sample.after_lp.variables["x[3]"].to_list = Mock( # type: ignore
return_value=[1.0, 1.0]
)
return sample
def test_xy(sample: Sample) -> None:
x_expected = {
"default": [
[5.0, 0.0, 0.0],
[5.0, 1.0, 0.0],
[5.0, 1.0, 1.0],
]
}
y_expected = {
"default": [
[True, False],
[False, True],
[True, False],
]
}
xy = PrimalSolutionComponent().sample_xy(sample)
assert xy is not None
x_actual, y_actual = xy
assert x_actual == x_expected
assert y_actual == y_expected
def test_xy_old() -> None:
features = Features(
variables={
"x[0]": Variable(

Loading…
Cancel
Save