Rewrite ObjectiveValueComponent.sample_xy

This commit is contained in:
2021-04-11 21:27:25 -05:00
parent 2da60dd293
commit d90d7762e3
5 changed files with 108 additions and 15 deletions

View File

@@ -7,7 +7,7 @@ from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable
import numpy as np
from overrides import EnforceOverrides
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Sample
from miplearn.instance.base import Instance
from miplearn.types import LearningSolveStats
@@ -119,6 +119,14 @@ class Component:
"""
pass
def sample_xy(self, sample: Sample) -> Tuple[Dict, Dict]:
"""
Returns a pair of x and y dictionaries containing, respectively, the matrices
of ML features and the labels for the sample. If the training sample does not
include label information, returns (x, {}).
"""
pass
def xy_instances(
self,
instances: List[Instance],

View File

@@ -12,7 +12,7 @@ from sklearn.linear_model import LinearRegression
from miplearn.classifiers import Regressor
from miplearn.classifiers.sklearn import ScikitLearnRegressor
from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Sample
from miplearn.instance.base import Instance
from miplearn.types import LearningSolveStats
@@ -98,6 +98,39 @@ class ObjectiveValueComponent(Component):
y["Upper bound"] = [[sample.upper_bound]]
return x, y
@overrides
def sample_xy(
self,
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features
assert sample.after_load is not None
assert sample.after_load.instance is not None
f = sample.after_load.instance.to_list()
# LP solve features
if sample.after_lp is not None:
assert sample.after_lp.lp_solve is not None
f.extend(sample.after_lp.lp_solve.to_list())
# Features
x: Dict[Hashable, List[List[float]]] = {
"Upper bound": [f],
"Lower bound": [f],
}
# Labels
y: Dict[Hashable, List[List[float]]] = {}
if sample.after_mip is not None:
mip_stats = sample.after_mip.mip_solve
assert mip_stats is not None
if mip_stats.mip_lower_bound is not None:
y["Lower bound"] = [[mip_stats.mip_lower_bound]]
if mip_stats.mip_upper_bound is not None:
y["Upper bound"] = [[mip_stats.mip_upper_bound]]
return x, y
@overrides
def sample_evaluate_old(
self,