mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Primal: Add end-to-end tests
This commit is contained in:
@@ -86,8 +86,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
else:
|
else:
|
||||||
self._n_one += 1
|
self._n_one += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Predicted: {self._n_free} free, {self._n_zero} fix-zero, "
|
f"Predicted: free: {self._n_free}, zero: {self._n_zero}, "
|
||||||
f"{self._n_one} fix-one"
|
f"one: {self._n_one}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Provide solution to the solver
|
# Provide solution to the solver
|
||||||
@@ -146,8 +146,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
thr = self.thresholds[category].predict(xc)
|
thr = self.thresholds[category].predict(xc)
|
||||||
y_pred[category] = np.vstack(
|
y_pred[category] = np.vstack(
|
||||||
[
|
[
|
||||||
proba[:, 0] > thr[0],
|
proba[:, 0] >= thr[0],
|
||||||
proba[:, 1] > thr[1],
|
proba[:, 1] >= thr[1],
|
||||||
]
|
]
|
||||||
).T
|
).T
|
||||||
|
|
||||||
|
|||||||
@@ -397,6 +397,7 @@ class LearningSolver:
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
|
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
|
||||||
|
logger.debug("Fitting...")
|
||||||
if len(training_instances) == 0:
|
if len(training_instances) == 0:
|
||||||
return
|
return
|
||||||
for component in self.components.values():
|
for component in self.components.values():
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.testing import assert_array_equal
|
from numpy.testing import assert_array_equal
|
||||||
|
from scipy.stats import randint
|
||||||
|
|
||||||
from miplearn import Classifier
|
from miplearn import Classifier, LearningSolver, GurobiSolver, GurobiPyomoSolver
|
||||||
from miplearn.classifiers.threshold import Threshold
|
from miplearn.classifiers.threshold import Threshold
|
||||||
from miplearn.components.primal import PrimalSolutionComponent
|
from miplearn.components.primal import PrimalSolutionComponent
|
||||||
|
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||||
from miplearn.types import TrainingSample, Features
|
from miplearn.types import TrainingSample, Features
|
||||||
|
from tests.fixtures.knapsack import get_knapsack_instance
|
||||||
|
|
||||||
|
|
||||||
def test_xy_sample_with_lp_solution() -> None:
|
def test_xy_sample_with_lp_solution() -> None:
|
||||||
@@ -210,3 +213,19 @@ def test_fit_xy():
|
|||||||
thr.fit.assert_called_once()
|
thr.fit.assert_called_once()
|
||||||
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
||||||
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage():
|
||||||
|
solver = LearningSolver(
|
||||||
|
components=[
|
||||||
|
PrimalSolutionComponent(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||||
|
instance = gen.generate(1)[0]
|
||||||
|
solver.solve(instance)
|
||||||
|
solver.fit([instance])
|
||||||
|
stats = solver.solve(instance)
|
||||||
|
assert stats["Primal: free"] == 0
|
||||||
|
assert stats["Primal: one"] + stats["Primal: zero"] == 10
|
||||||
|
assert stats["Lower bound"] == stats["Warm start value"]
|
||||||
|
|||||||
Reference in New Issue
Block a user