mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-10 11:28:51 -06:00
Rewrite ObjectiveValueComponent.sample_xy
This commit is contained in:
@@ -12,7 +12,7 @@ from sklearn.linear_model import LinearRegression
|
||||
from miplearn.classifiers import Regressor
|
||||
from miplearn.classifiers.sklearn import ScikitLearnRegressor
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.features import TrainingSample, Features, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
@@ -98,6 +98,39 @@ class ObjectiveValueComponent(Component):
|
||||
y["Upper bound"] = [[sample.upper_bound]]
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_xy(
|
||||
self,
|
||||
sample: Sample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
# Instance features
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
f = sample.after_load.instance.to_list()
|
||||
|
||||
# LP solve features
|
||||
if sample.after_lp is not None:
|
||||
assert sample.after_lp.lp_solve is not None
|
||||
f.extend(sample.after_lp.lp_solve.to_list())
|
||||
|
||||
# Features
|
||||
x: Dict[Hashable, List[List[float]]] = {
|
||||
"Upper bound": [f],
|
||||
"Lower bound": [f],
|
||||
}
|
||||
|
||||
# Labels
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
if sample.after_mip is not None:
|
||||
mip_stats = sample.after_mip.mip_solve
|
||||
assert mip_stats is not None
|
||||
if mip_stats.mip_lower_bound is not None:
|
||||
y["Lower bound"] = [[mip_stats.mip_lower_bound]]
|
||||
if mip_stats.mip_upper_bound is not None:
|
||||
y["Upper bound"] = [[mip_stats.mip_upper_bound]]
|
||||
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_evaluate_old(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user