Primal: reactivate before_solve_mip

This commit is contained in:
2021-03-31 12:07:58 -05:00
parent fe7bad885c
commit db2f426140
7 changed files with 133 additions and 102 deletions

View File

@@ -2,7 +2,6 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import cast
from unittest.mock import Mock
import numpy as np
@@ -11,7 +10,6 @@ from numpy.testing import assert_array_equal
from miplearn import Classifier
from miplearn.classifiers.threshold import Threshold
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.instance import Instance
from miplearn.types import TrainingSample, Features
@@ -142,8 +140,7 @@ def test_predict() -> None:
)
thr = Mock(spec=Threshold)
thr.predict = Mock(return_value=[0.75, 0.75])
instance = cast(Instance, Mock(spec=Instance))
instance.features = {
features: Features = {
"Variables": {
"x": {
0: {
@@ -161,33 +158,23 @@ def test_predict() -> None:
}
}
}
instance.training_data = [
{
"LP solution": {
"x": {
0: 0.1,
1: 0.5,
2: 0.9,
}
sample: TrainingSample = {
"LP solution": {
"x": {
0: 0.1,
1: 0.5,
2: 0.9,
}
}
]
x = {
"default": np.array(
[
[0.0, 0.0, 0.1],
[0.0, 2.0, 0.5],
[2.0, 0.0, 0.9],
]
)
}
x = PrimalSolutionComponent.x_sample(features, sample)
comp = PrimalSolutionComponent()
comp.classifiers = {"default": clf}
comp.thresholds = {"default": thr}
solution_actual = comp.predict(instance)
solution_actual = comp.predict(features, sample)
clf.predict_proba.assert_called_once()
thr.predict.assert_called_once()
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
thr.predict.assert_called_once()
assert_array_equal(x["default"], thr.predict.call_args[0][0])
assert solution_actual == {
"x": {
@@ -196,3 +183,30 @@ def test_predict() -> None:
2: 1.0,
}
}
def test_fit_xy():
comp = PrimalSolutionComponent(
classifier=lambda: Mock(spec=Classifier),
threshold=lambda: Mock(spec=Threshold),
)
x = {
"type-a": np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
"type-b": np.array([[7.0, 8.0, 9.0]]),
}
y = {
"type-a": np.array([[True, False], [False, True]]),
"type-b": np.array([[True, False]]),
}
comp.fit_xy(x, y)
for category in ["type-a", "type-b"]:
assert category in comp.classifiers
assert category in comp.thresholds
clf = comp.classifiers[category]
clf.fit.assert_called_once()
assert_array_equal(x[category], clf.fit.call_args[0][0])
assert_array_equal(y[category], clf.fit.call_args[0][1])
thr = comp.thresholds[category]
thr.fit.assert_called_once()
assert_array_equal(x[category], thr.fit.call_args[0][1])
assert_array_equal(y[category], thr.fit.call_args[0][2])

View File

@@ -12,24 +12,25 @@ from miplearn.solvers.learning import LearningSolver
def test_benchmark():
# Generate training and test instances
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
train_instances = generator.generate(5)
test_instances = generator.generate(3)
for n_jobs in [1, 4]:
# Generate training and test instances
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
train_instances = generator.generate(5)
test_instances = generator.generate(3)
# Training phase...
training_solver = LearningSolver()
training_solver.parallel_solve(train_instances, n_jobs=10)
# Solve training instances
training_solver = LearningSolver()
training_solver.parallel_solve(train_instances, n_jobs=n_jobs)
# Test phase...
test_solvers = {
"Strategy A": LearningSolver(),
"Strategy B": LearningSolver(),
}
benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.results.values.shape == (12, 18)
# Benchmark
test_solvers = {
"Strategy A": LearningSolver(),
"Strategy B": LearningSolver(),
}
benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=n_jobs, n_trials=2)
assert benchmark.results.values.shape == (12, 18)
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")