mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Primal: Add end-to-end tests
This commit is contained in:
@@ -6,11 +6,14 @@ from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_equal
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver
|
||||
from miplearn.classifiers.threshold import Threshold
|
||||
from miplearn.components.primal import PrimalSolutionComponent
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from tests.fixtures.knapsack import get_knapsack_instance
|
||||
|
||||
|
||||
def test_xy_sample_with_lp_solution() -> None:
|
||||
@@ -210,3 +213,19 @@ def test_fit_xy():
|
||||
thr.fit.assert_called_once()
|
||||
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
||||
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
||||
|
||||
|
||||
def test_usage():
|
||||
solver = LearningSolver(
|
||||
components=[
|
||||
PrimalSolutionComponent(),
|
||||
]
|
||||
)
|
||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||
instance = gen.generate(1)[0]
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
assert stats["Primal: free"] == 0
|
||||
assert stats["Primal: one"] + stats["Primal: zero"] == 10
|
||||
assert stats["Lower bound"] == stats["Warm start value"]
|
||||
|
||||
Reference in New Issue
Block a user