Implement ObjectiveValueComponent.evaluate

pull/3/head
Alinson S. Xavier 6 years ago
parent b7ff587fb1
commit b1871869a0

@ -1,6 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from sklearn.metrics import mean_squared_error, explained_variance_score, max_error, mean_absolute_error, r2_score
from .. import Component, InstanceFeaturesExtractor, ObjectiveValueExtractor from .. import Component, InstanceFeaturesExtractor, ObjectiveValueExtractor
from sklearn.linear_model import LinearRegression from sklearn.linear_model import LinearRegression
@ -52,3 +53,28 @@ class ObjectiveValueComponent(Component):
lb = self.lb_regressor.predict(features) lb = self.lb_regressor.predict(features)
ub = self.ub_regressor.predict(features) ub = self.ub_regressor.predict(features)
return np.hstack([lb, ub]) return np.hstack([lb, ub])
def evaluate(self, instances):
y_pred = self.predict(instances)
y_true = np.array([[inst.lower_bound, inst.upper_bound] for inst in instances])
y_true_lb, y_true_ub = y_true[:, 0], y_true[:, 1]
y_pred_lb, y_pred_ub = y_pred[:, 1], y_pred[:, 1]
ev = {
"Lower bound": {
"Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
"Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
"Max error": max_error(y_true_lb, y_pred_lb),
"Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
"R2": np.round(r2_score(y_true_lb, y_pred_lb), 3),
"Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
},
"Upper bound": {
"Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
"Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
"Max error": max_error(y_true_ub, y_pred_ub),
"Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
"R2": np.round(r2_score(y_true_ub, y_pred_ub), 3),
"Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
},
}
return ev

@ -2,28 +2,46 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from miplearn import ObjectiveValueComponent, LearningSolver from unittest.mock import Mock
from miplearn.problems.knapsack import KnapsackInstance
def _get_instances(): import numpy as np
instances = [ from miplearn import ObjectiveValueComponent
KnapsackInstance( from miplearn.classifiers import Regressor
weights=[23., 26., 20., 18.], from miplearn.tests import get_training_instances_and_models
prices=[505., 352., 458., 220.],
capacity=67.,
),
]
models = [instance.to_model() for instance in instances]
solver = LearningSolver()
for i in range(len(instances)):
solver.solve(instances[i], models[i])
return instances, models
def test_usage(): def test_usage():
instances, models = _get_instances() instances, models = get_training_instances_and_models()
comp = ObjectiveValueComponent() comp = ObjectiveValueComponent()
comp.fit(instances) comp.fit(instances)
assert instances[0].lower_bound == 1183.0 assert instances[0].lower_bound == 1183.0
assert instances[0].upper_bound == 1183.0 assert instances[0].upper_bound == 1183.0
assert comp.predict(instances).tolist() == [[1183.0, 1183.0]] assert comp.predict(instances).tolist() == [[1183.0, 1183.0],
[1070.0, 1070.0]]
def test_obj_evaluate():
instances, models = get_training_instances_and_models()
reg = Mock(spec=Regressor)
reg.predict = Mock(return_value=np.array([[1000.0], [1000.0]]))
comp = ObjectiveValueComponent(regressor=reg)
comp.fit(instances)
ev = comp.evaluate(instances)
assert ev == {
'Lower bound': {
'Explained variance': 0.0,
'Max error': 183.0,
'Mean absolute error': 126.5,
'Mean squared error': 19194.5,
'Median absolute error': 126.5,
'R2': -5.013,
},
'Upper bound': {
'Explained variance': 0.0,
'Max error': 183.0,
'Mean absolute error': 126.5,
'Mean squared error': 19194.5,
'Median absolute error': 126.5,
'R2': -5.013,
}
}

Loading…
Cancel
Save