Primal: Add end-to-end tests

master
Alinson S. Xavier 5 years ago
parent db2f426140
commit b83911a91d

@ -86,8 +86,8 @@ class PrimalSolutionComponent(Component):
else: else:
self._n_one += 1 self._n_one += 1
logger.info( logger.info(
f"Predicted: {self._n_free} free, {self._n_zero} fix-zero, " f"Predicted: free: {self._n_free}, zero: {self._n_zero}, "
f"{self._n_one} fix-one" f"one: {self._n_one}"
) )
# Provide solution to the solver # Provide solution to the solver
@ -146,8 +146,8 @@ class PrimalSolutionComponent(Component):
thr = self.thresholds[category].predict(xc) thr = self.thresholds[category].predict(xc)
y_pred[category] = np.vstack( y_pred[category] = np.vstack(
[ [
proba[:, 0] > thr[0], proba[:, 0] >= thr[0],
proba[:, 1] > thr[1], proba[:, 1] >= thr[1],
] ]
).T ).T

@ -397,6 +397,7 @@ class LearningSolver:
return stats return stats
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None: def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
logger.debug("Fitting...")
if len(training_instances) == 0: if len(training_instances) == 0:
return return
for component in self.components.values(): for component in self.components.values():

@ -6,11 +6,14 @@ from unittest.mock import Mock
import numpy as np import numpy as np
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.stats import randint
from miplearn import Classifier from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver
from miplearn.classifiers.threshold import Threshold from miplearn.classifiers.threshold import Threshold
from miplearn.components.primal import PrimalSolutionComponent from miplearn.components.primal import PrimalSolutionComponent
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.types import TrainingSample, Features from miplearn.types import TrainingSample, Features
from tests.fixtures.knapsack import get_knapsack_instance
def test_xy_sample_with_lp_solution() -> None: def test_xy_sample_with_lp_solution() -> None:
@ -210,3 +213,19 @@ def test_fit_xy():
thr.fit.assert_called_once() thr.fit.assert_called_once()
assert_array_equal(x[category], thr.fit.call_args[0][1]) assert_array_equal(x[category], thr.fit.call_args[0][1])
assert_array_equal(y[category], thr.fit.call_args[0][2]) assert_array_equal(y[category], thr.fit.call_args[0][2])
def test_usage():
solver = LearningSolver(
components=[
PrimalSolutionComponent(),
]
)
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
instance = gen.generate(1)[0]
solver.solve(instance)
solver.fit([instance])
stats = solver.solve(instance)
assert stats["Primal: free"] == 0
assert stats["Primal: one"] + stats["Primal: zero"] == 10
assert stats["Lower bound"] == stats["Warm start value"]

Loading…
Cancel
Save