Make WarmStartComponent use Extractor

This commit is contained in:
2020-02-04 13:29:06 -06:00
parent 17c21382c5
commit dbea4aa988
5 changed files with 88 additions and 49 deletions

View File

@@ -3,8 +3,10 @@
# Written by Alinson S. Xavier <axavier@anl.gov>
from miplearn.problems.knapsack import KnapsackInstance
from miplearn import (UserFeaturesExtractor,
SolutionExtractor)
from miplearn.extractors import (UserFeaturesExtractor,
SolutionExtractor,
CombinedExtractor,
)
import numpy as np
import pyomo.environ as pe
@@ -52,3 +54,16 @@ def test_solution_extractor():
0., 1.,
1., 0.,
]
def test_combined_extractor():
instances = _get_instances()
models = [instance.to_model() for instance in instances]
extractor = CombinedExtractor(extractors=[UserFeaturesExtractor(),
SolutionExtractor()])
features = extractor.extract(instances, models)
assert isinstance(features, dict)
assert "default" in features.keys()
assert isinstance(features["default"], np.ndarray)
assert features["default"].shape == (6, 6)

View File

@@ -24,7 +24,7 @@ def test_warm_start_save_load():
solver.parallel_solve(_get_instances(), n_jobs=2)
solver.fit()
comp = solver.components["warm-start"]
assert comp.x_train["default"].shape == (8, 4)
assert comp.x_train["default"].shape == (8, 6)
assert comp.y_train["default"].shape == (8, 2)
assert "default" in comp.predictors.keys()
solver.save_state(state_file.name)
@@ -32,6 +32,6 @@ def test_warm_start_save_load():
solver = LearningSolver(components={"warm-start": WarmStartComponent()})
solver.load_state(state_file.name)
comp = solver.components["warm-start"]
assert comp.x_train["default"].shape == (8, 4)
assert comp.x_train["default"].shape == (8, 6)
assert comp.y_train["default"].shape == (8, 2)
assert "default" in comp.predictors.keys()