mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Primal: Add end-to-end tests
This commit is contained in:
@@ -86,8 +86,8 @@ class PrimalSolutionComponent(Component):
|
||||
else:
|
||||
self._n_one += 1
|
||||
logger.info(
|
||||
f"Predicted: {self._n_free} free, {self._n_zero} fix-zero, "
|
||||
f"{self._n_one} fix-one"
|
||||
f"Predicted: free: {self._n_free}, zero: {self._n_zero}, "
|
||||
f"one: {self._n_one}"
|
||||
)
|
||||
|
||||
# Provide solution to the solver
|
||||
@@ -146,8 +146,8 @@ class PrimalSolutionComponent(Component):
|
||||
thr = self.thresholds[category].predict(xc)
|
||||
y_pred[category] = np.vstack(
|
||||
[
|
||||
proba[:, 0] > thr[0],
|
||||
proba[:, 1] > thr[1],
|
||||
proba[:, 0] >= thr[0],
|
||||
proba[:, 1] >= thr[1],
|
||||
]
|
||||
).T
|
||||
|
||||
|
||||
@@ -397,6 +397,7 @@ class LearningSolver:
|
||||
return stats
|
||||
|
||||
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
|
||||
logger.debug("Fitting...")
|
||||
if len(training_instances) == 0:
|
||||
return
|
||||
for component in self.components.values():
|
||||
|
||||
@@ -6,11 +6,14 @@ from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_equal
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver
|
||||
from miplearn.classifiers.threshold import Threshold
|
||||
from miplearn.components.primal import PrimalSolutionComponent
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from tests.fixtures.knapsack import get_knapsack_instance
|
||||
|
||||
|
||||
def test_xy_sample_with_lp_solution() -> None:
|
||||
@@ -210,3 +213,19 @@ def test_fit_xy():
|
||||
thr.fit.assert_called_once()
|
||||
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
||||
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
||||
|
||||
|
||||
def test_usage():
|
||||
solver = LearningSolver(
|
||||
components=[
|
||||
PrimalSolutionComponent(),
|
||||
]
|
||||
)
|
||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||
instance = gen.generate(1)[0]
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
assert stats["Primal: free"] == 0
|
||||
assert stats["Primal: one"] + stats["Primal: zero"] == 10
|
||||
assert stats["Lower bound"] == stats["Warm start value"]
|
||||
|
||||
Reference in New Issue
Block a user