mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement ObjectiveValueComponent.evaluate
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from sklearn.metrics import mean_squared_error, explained_variance_score, max_error, mean_absolute_error, r2_score
|
||||
|
||||
from .. import Component, InstanceFeaturesExtractor, ObjectiveValueExtractor
|
||||
from sklearn.linear_model import LinearRegression
|
||||
@@ -52,3 +53,28 @@ class ObjectiveValueComponent(Component):
|
||||
lb = self.lb_regressor.predict(features)
|
||||
ub = self.ub_regressor.predict(features)
|
||||
return np.hstack([lb, ub])
|
||||
|
||||
def evaluate(self, instances):
|
||||
y_pred = self.predict(instances)
|
||||
y_true = np.array([[inst.lower_bound, inst.upper_bound] for inst in instances])
|
||||
y_true_lb, y_true_ub = y_true[:, 0], y_true[:, 1]
|
||||
y_pred_lb, y_pred_ub = y_pred[:, 1], y_pred[:, 1]
|
||||
ev = {
|
||||
"Lower bound": {
|
||||
"Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
|
||||
"Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
|
||||
"Max error": max_error(y_true_lb, y_pred_lb),
|
||||
"Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
"R2": np.round(r2_score(y_true_lb, y_pred_lb), 3),
|
||||
"Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
},
|
||||
"Upper bound": {
|
||||
"Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
|
||||
"Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
|
||||
"Max error": max_error(y_true_ub, y_pred_ub),
|
||||
"Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
"R2": np.round(r2_score(y_true_ub, y_pred_ub), 3),
|
||||
"Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
},
|
||||
}
|
||||
return ev
|
||||
|
||||
@@ -2,28 +2,46 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from miplearn import ObjectiveValueComponent, LearningSolver
|
||||
from miplearn.problems.knapsack import KnapsackInstance
|
||||
from unittest.mock import Mock
|
||||
|
||||
def _get_instances():
|
||||
instances = [
|
||||
KnapsackInstance(
|
||||
weights=[23., 26., 20., 18.],
|
||||
prices=[505., 352., 458., 220.],
|
||||
capacity=67.,
|
||||
),
|
||||
]
|
||||
models = [instance.to_model() for instance in instances]
|
||||
solver = LearningSolver()
|
||||
for i in range(len(instances)):
|
||||
solver.solve(instances[i], models[i])
|
||||
return instances, models
|
||||
import numpy as np
|
||||
from miplearn import ObjectiveValueComponent
|
||||
from miplearn.classifiers import Regressor
|
||||
from miplearn.tests import get_training_instances_and_models
|
||||
|
||||
|
||||
def test_usage():
|
||||
instances, models = _get_instances()
|
||||
instances, models = get_training_instances_and_models()
|
||||
comp = ObjectiveValueComponent()
|
||||
comp.fit(instances)
|
||||
assert instances[0].lower_bound == 1183.0
|
||||
assert instances[0].upper_bound == 1183.0
|
||||
assert comp.predict(instances).tolist() == [[1183.0, 1183.0]]
|
||||
assert comp.predict(instances).tolist() == [[1183.0, 1183.0],
|
||||
[1070.0, 1070.0]]
|
||||
|
||||
|
||||
def test_obj_evaluate():
|
||||
instances, models = get_training_instances_and_models()
|
||||
reg = Mock(spec=Regressor)
|
||||
reg.predict = Mock(return_value=np.array([[1000.0], [1000.0]]))
|
||||
comp = ObjectiveValueComponent(regressor=reg)
|
||||
comp.fit(instances)
|
||||
ev = comp.evaluate(instances)
|
||||
assert ev == {
|
||||
'Lower bound': {
|
||||
'Explained variance': 0.0,
|
||||
'Max error': 183.0,
|
||||
'Mean absolute error': 126.5,
|
||||
'Mean squared error': 19194.5,
|
||||
'Median absolute error': 126.5,
|
||||
'R2': -5.013,
|
||||
},
|
||||
'Upper bound': {
|
||||
'Explained variance': 0.0,
|
||||
'Max error': 183.0,
|
||||
'Mean absolute error': 126.5,
|
||||
'Mean squared error': 19194.5,
|
||||
'Median absolute error': 126.5,
|
||||
'R2': -5.013,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user