mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Rewrite ObjectiveValueComponent.sample_xy
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable
|
||||
import numpy as np
|
||||
from overrides import EnforceOverrides
|
||||
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.features import TrainingSample, Features, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
@@ -119,6 +119,14 @@ class Component:
|
||||
"""
|
||||
pass
|
||||
|
||||
def sample_xy(self, sample: Sample) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Returns a pair of x and y dictionaries containing, respectively, the matrices
|
||||
of ML features and the labels for the sample. If the training sample does not
|
||||
include label information, returns (x, {}).
|
||||
"""
|
||||
pass
|
||||
|
||||
def xy_instances(
|
||||
self,
|
||||
instances: List[Instance],
|
||||
|
||||
@@ -12,7 +12,7 @@ from sklearn.linear_model import LinearRegression
|
||||
from miplearn.classifiers import Regressor
|
||||
from miplearn.classifiers.sklearn import ScikitLearnRegressor
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.features import TrainingSample, Features, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
@@ -98,6 +98,39 @@ class ObjectiveValueComponent(Component):
|
||||
y["Upper bound"] = [[sample.upper_bound]]
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_xy(
|
||||
self,
|
||||
sample: Sample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
# Instance features
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
f = sample.after_load.instance.to_list()
|
||||
|
||||
# LP solve features
|
||||
if sample.after_lp is not None:
|
||||
assert sample.after_lp.lp_solve is not None
|
||||
f.extend(sample.after_lp.lp_solve.to_list())
|
||||
|
||||
# Features
|
||||
x: Dict[Hashable, List[List[float]]] = {
|
||||
"Upper bound": [f],
|
||||
"Lower bound": [f],
|
||||
}
|
||||
|
||||
# Labels
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
if sample.after_mip is not None:
|
||||
mip_stats = sample.after_mip.mip_solve
|
||||
assert mip_stats is not None
|
||||
if mip_stats.mip_lower_bound is not None:
|
||||
y["Lower bound"] = [[mip_stats.mip_lower_bound]]
|
||||
if mip_stats.mip_upper_bound is not None:
|
||||
y["Upper bound"] = [[mip_stats.mip_upper_bound]]
|
||||
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_evaluate_old(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user