mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Rewrite PrimalSolutionComponent.sample_xy
This commit is contained in:
@@ -20,7 +20,7 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.features import TrainingSample, Features, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import (
|
||||
LearningSolveStats,
|
||||
@@ -179,6 +179,50 @@ class PrimalSolutionComponent(Component):
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_xy(
|
||||
self,
|
||||
sample: Sample,
|
||||
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.variables is not None
|
||||
for (var_name, var) in sample.after_load.variables.items():
|
||||
# Initialize categories
|
||||
category = var.category
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
|
||||
# Features
|
||||
sf = sample.after_load
|
||||
if sample.after_lp is not None:
|
||||
sf = sample.after_lp
|
||||
assert sf.instance is not None
|
||||
features = list(sf.instance.to_list())
|
||||
assert sf.variables is not None
|
||||
assert sf.variables[var_name] is not None
|
||||
features.extend(sf.variables[var_name].to_list())
|
||||
x[category].append(features)
|
||||
|
||||
# Labels
|
||||
if sample.after_mip is not None:
|
||||
assert sample.after_mip.variables is not None
|
||||
assert sample.after_mip.variables[var_name] is not None
|
||||
opt_value = sample.after_mip.variables[var_name].value
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var_name} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
y[category].append([opt_value < 0.5, opt_value >= 0.5])
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_evaluate_old(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user