|
|
|
@ -6,11 +6,14 @@ from unittest.mock import Mock
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numpy.testing import assert_array_equal
|
|
|
|
|
from scipy.stats import randint
|
|
|
|
|
|
|
|
|
|
from miplearn import Classifier
|
|
|
|
|
from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver
|
|
|
|
|
from miplearn.classifiers.threshold import Threshold
|
|
|
|
|
from miplearn.components.primal import PrimalSolutionComponent
|
|
|
|
|
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
|
|
|
|
from miplearn.types import TrainingSample, Features
|
|
|
|
|
from tests.fixtures.knapsack import get_knapsack_instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_xy_sample_with_lp_solution() -> None:
|
|
|
|
@ -210,3 +213,19 @@ def test_fit_xy():
|
|
|
|
|
thr.fit.assert_called_once()
|
|
|
|
|
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
|
|
|
|
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_usage():
|
|
|
|
|
solver = LearningSolver(
|
|
|
|
|
components=[
|
|
|
|
|
PrimalSolutionComponent(),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
|
|
|
|
instance = gen.generate(1)[0]
|
|
|
|
|
solver.solve(instance)
|
|
|
|
|
solver.fit([instance])
|
|
|
|
|
stats = solver.solve(instance)
|
|
|
|
|
assert stats["Primal: free"] == 0
|
|
|
|
|
assert stats["Primal: one"] + stats["Primal: zero"] == 10
|
|
|
|
|
assert stats["Lower bound"] == stats["Warm start value"]
|
|
|
|
|