Implement InstanceFeaturesExtractor and ObjectiveValueExtractor

This commit is contained in:
2020-02-23 12:22:40 -06:00
parent b428a4fc36
commit 7de1db047f
3 changed files with 43 additions and 15 deletions

View File

@@ -7,13 +7,14 @@ from miplearn import (LearningSolver,
UserFeaturesExtractor,
SolutionExtractor,
CombinedExtractor,
InstanceFeaturesExtractor
)
import numpy as np
import pyomo.environ as pe
def _get_instances():
return [
instances = [
KnapsackInstance(weights=[1., 2., 3.],
prices=[10., 20., 30.],
capacity=2.5,
@@ -23,10 +24,15 @@ def _get_instances():
capacity=4.5,
),
]
models = [instance.to_model() for instance in instances]
solver = LearningSolver()
for (i, instance) in enumerate(instances):
solver.solve(instances[i], models[i])
return instances, models
def test_user_features():
instances = _get_instances()
def test_user_features_extractor():
instances, models = _get_instances()
extractor = UserFeaturesExtractor()
features = extractor.extract(instances)
assert isinstance(features, dict)
@@ -36,12 +42,7 @@ def test_user_features():
def test_solution_extractor():
instances = _get_instances()
models = [instance.to_model() for instance in instances]
solver = LearningSolver()
for (i, instance) in enumerate(instances):
solver.solve(instances[i], models[i])
instances, models = _get_instances()
features = SolutionExtractor().extract(instances, models)
assert isinstance(features, dict)
assert "default" in features.keys()
@@ -58,12 +59,7 @@ def test_solution_extractor():
def test_combined_extractor():
instances = _get_instances()
models = [instance.to_model() for instance in instances]
solver = LearningSolver()
for (i, instance) in enumerate(instances):
solver.solve(instances[i], models[i])
instances, models = _get_instances()
extractor = CombinedExtractor(extractors=[UserFeaturesExtractor(),
SolutionExtractor()])
features = extractor.extract(instances, models)
@@ -72,3 +68,8 @@ def test_combined_extractor():
assert isinstance(features["default"], np.ndarray)
assert features["default"].shape == (6, 6)
def test_instance_features_extractor():
instances, models = _get_instances()
features = InstanceFeaturesExtractor().extract(instances)
assert features.shape == (2,3)