Rewrite PrimalSolutionComponent.sample_xy

This commit is contained in:
2021-04-11 21:52:59 -05:00
parent d90d7762e3
commit 2979bd157c
2 changed files with 118 additions and 3 deletions

View File

@@ -20,7 +20,7 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Sample
from miplearn.instance.base import Instance
from miplearn.types import (
LearningSolveStats,
@@ -179,6 +179,50 @@ class PrimalSolutionComponent(Component):
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y
@overrides
def sample_xy(
self,
sample: Sample,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}
assert sample.after_load is not None
assert sample.after_load.variables is not None
for (var_name, var) in sample.after_load.variables.items():
# Initialize categories
category = var.category
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
# Features
sf = sample.after_load
if sample.after_lp is not None:
sf = sample.after_lp
assert sf.instance is not None
features = list(sf.instance.to_list())
assert sf.variables is not None
assert sf.variables[var_name] is not None
features.extend(sf.variables[var_name].to_list())
x[category].append(features)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.variables is not None
assert sample.after_mip.variables[var_name] is not None
opt_value = sample.after_mip.variables[var_name].value
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
y[category].append([opt_value < 0.5, opt_value >= 0.5])
return x, y
@overrides
def sample_evaluate_old(
self,