From c82de560f4d54fd33934c98e44a7d74c52d3bb0a Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Tue, 4 Feb 2020 11:33:23 -0600 Subject: [PATCH] Implement UserFeaturesExtractor and SolutionExtractor --- miplearn/extractors.py | 70 +++++++++++++++++++++++++++++++ miplearn/tests/test_extractors.py | 54 ++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 miplearn/extractors.py create mode 100644 miplearn/tests/test_extractors.py diff --git a/miplearn/extractors.py b/miplearn/extractors.py new file mode 100644 index 0000000..3b9b4fe --- /dev/null +++ b/miplearn/extractors.py @@ -0,0 +1,70 @@ +# MIPLearn, an extensible framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved. +# Written by Alinson S. Xavier + +import numpy as np +from abc import ABC, abstractmethod +from pyomo.core import Var + + +class Extractor(ABC): + @abstractmethod + def extract(self, instances, models): + pass + + @staticmethod + def split_variables(instance, model): + result = {} + for var in model.component_objects(Var): + for index in var: + category = instance.get_variable_category(var, index) + if category is None: + continue + if category not in result.keys(): + result[category] = [] + result[category] += [(var, index)] + return result + + +class UserFeaturesExtractor(Extractor): + def extract(self, + instances, + models=None, + ): + result = {} + if models is None: + models = [instance.to_model() for instance in instances] + for (index, instance) in enumerate(instances): + model = models[index] + instance_features = instance.get_instance_features() + var_split = self.split_variables(instance, model) + for (category, var_index_pairs) in var_split.items(): + if category not in result.keys(): + result[category] = [] + for (var, index) in var_index_pairs: + result[category] += [np.hstack([ + instance_features, + instance.get_variable_features(var, index), + ])] + for category in result.keys(): + result[category] = np.vstack(result[category]) + return result + + +class SolutionExtractor(Extractor): + def extract(self, instances, models): + result = {} + for (index, instance) in enumerate(instances): + model = models[index] + var_split = self.split_variables(instance, model) + for (category, var_index_pairs) in var_split.items(): + if category not in result.keys(): + result[category] = [] + for (var, index) in var_index_pairs: + result[category] += [[ + 1 - var[index].value, + var[index].value, + ]] + for category in result.keys(): + result[category] = np.vstack(result[category]) + return result \ No newline at end of file diff --git a/miplearn/tests/test_extractors.py b/miplearn/tests/test_extractors.py new file mode 100644 index 0000000..8349db5 --- /dev/null +++ b/miplearn/tests/test_extractors.py @@ -0,0 +1,54 @@ +# MIPLearn, an extensible framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2019-2020 Argonne National Laboratory. All rights reserved. +# Written by Alinson S. Xavier + +from miplearn.problems.knapsack import KnapsackInstance +from miplearn import (UserFeaturesExtractor, + SolutionExtractor) +import numpy as np +import pyomo.environ as pe + + +def _get_instances(): + return [ + KnapsackInstance(weights=[1., 2., 3.], + prices=[10., 20., 30.], + capacity=2.5, + ), + KnapsackInstance(weights=[3., 4., 5.], + prices=[20., 30., 40.], + capacity=4.5, + ), + ] + + +def test_user_features(): + instances = _get_instances() + extractor = UserFeaturesExtractor() + features = extractor.extract(instances) + assert isinstance(features, dict) + assert "default" in features.keys() + assert isinstance(features["default"], np.ndarray) + assert features["default"].shape == (6, 4) + + +def test_solution_extractor(): + instances = _get_instances() + models = [instance.to_model() for instance in instances] + for model in models: + solver = pe.SolverFactory("cbc") + solver.solve(model) + extractor = SolutionExtractor() + features = extractor.extract(instances, models) + assert isinstance(features, dict) + assert "default" in features.keys() + assert isinstance(features["default"], np.ndarray) + assert features["default"].shape == (6, 2) + assert features["default"].ravel().tolist() == [ + 1., 0., + 0., 1., + 1., 0., + 1., 0., + 0., 1., + 1., 0., + ]