From b83911a91d492b914b56b9f5e2dc6cf275b30ca6 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Wed, 31 Mar 2021 12:38:23 -0500 Subject: [PATCH] Primal: Add end-to-end tests --- miplearn/components/primal.py | 8 ++++---- miplearn/solvers/learning.py | 1 + tests/components/test_primal.py | 21 ++++++++++++++++++++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index b0ef043..8d15c63 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -86,8 +86,8 @@ class PrimalSolutionComponent(Component): else: self._n_one += 1 logger.info( - f"Predicted: {self._n_free} free, {self._n_zero} fix-zero, " - f"{self._n_one} fix-one" + f"Predicted: free: {self._n_free}, zero: {self._n_zero}, " + f"one: {self._n_one}" ) # Provide solution to the solver @@ -146,8 +146,8 @@ class PrimalSolutionComponent(Component): thr = self.thresholds[category].predict(xc) y_pred[category] = np.vstack( [ - proba[:, 0] > thr[0], - proba[:, 1] > thr[1], + proba[:, 0] >= thr[0], + proba[:, 1] >= thr[1], ] ).T diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index d16bf6d..982443a 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -397,6 +397,7 @@ class LearningSolver: return stats def fit(self, training_instances: Union[List[str], List[Instance]]) -> None: + logger.debug("Fitting...") if len(training_instances) == 0: return for component in self.components.values(): diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 7f713d9..8ce1aa6 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -6,11 +6,14 @@ from unittest.mock import Mock import numpy as np from numpy.testing import assert_array_equal +from scipy.stats import randint -from miplearn import Classifier +from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver from miplearn.classifiers.threshold import Threshold from miplearn.components.primal import PrimalSolutionComponent +from miplearn.problems.tsp import TravelingSalesmanGenerator from miplearn.types import TrainingSample, Features +from tests.fixtures.knapsack import get_knapsack_instance def test_xy_sample_with_lp_solution() -> None: @@ -210,3 +213,19 @@ def test_fit_xy(): thr.fit.assert_called_once() assert_array_equal(x[category], thr.fit.call_args[0][1]) assert_array_equal(y[category], thr.fit.call_args[0][2]) + + +def test_usage(): + solver = LearningSolver( + components=[ + PrimalSolutionComponent(), + ] + ) + gen = TravelingSalesmanGenerator(n=randint(low=5, high=6)) + instance = gen.generate(1)[0] + solver.solve(instance) + solver.fit([instance]) + stats = solver.solve(instance) + assert stats["Primal: free"] == 0 + assert stats["Primal: one"] + stats["Primal: zero"] == 10 + assert stats["Lower bound"] == stats["Warm start value"]