mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Objective: Add tests
This commit is contained in:
@@ -8,11 +8,12 @@ from unittest.mock import Mock
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from miplearn import GurobiPyomoSolver, LearningSolver
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.classifiers import Regressor
|
||||
from miplearn.components.objective import ObjectiveValueComponent
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from tests.fixtures.knapsack import get_test_pyomo_instances
|
||||
from tests.fixtures.knapsack import get_test_pyomo_instances, get_knapsack_instance
|
||||
|
||||
|
||||
def test_x_y_predict() -> None:
|
||||
@@ -151,3 +152,13 @@ def test_xy_sample_without_lp() -> None:
|
||||
x_actual, y_actual = xy
|
||||
assert x_actual == x_expected
|
||||
assert y_actual == y_expected
|
||||
|
||||
|
||||
def test_usage():
|
||||
solver = LearningSolver(components=[ObjectiveValueComponent()])
|
||||
instance = get_knapsack_instance(GurobiPyomoSolver())
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
assert stats["Lower bound"] == stats["Objective: predicted LB"]
|
||||
assert stats["Upper bound"] == stats["Objective: predicted UB"]
|
||||
|
||||
Reference in New Issue
Block a user