mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Redesign InternalSolver constraint methods
This commit is contained in:
@@ -12,6 +12,7 @@ from gurobipy import GRB
|
||||
from networkx import Graph
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn import InternalSolver
|
||||
from miplearn.components.dynamic_user_cuts import UserCutsComponent
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
@@ -49,10 +50,15 @@ class GurobiStableSetProblem(Instance):
|
||||
return violations
|
||||
|
||||
@overrides
|
||||
def build_user_cut(self, model: Any, cid: Hashable) -> Any:
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: InternalSolver,
|
||||
model: Any,
|
||||
cid: Hashable,
|
||||
) -> Any:
|
||||
assert isinstance(cid, FrozenSet)
|
||||
x = model.getVars()
|
||||
return gp.quicksum([x[i] for i in cid]) <= 1
|
||||
model.addConstr(gp.quicksum([x[i] for i in cid]) <= 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -78,12 +78,14 @@ def features() -> Features:
|
||||
|
||||
|
||||
def test_usage_with_solver(instance: Instance) -> None:
|
||||
assert instance.features is not None
|
||||
assert instance.features.constraints is not None
|
||||
|
||||
solver = Mock(spec=LearningSolver)
|
||||
solver.use_lazy_cb = False
|
||||
solver.gap_tolerance = 1e-4
|
||||
|
||||
internal = solver.internal_solver = Mock(spec=InternalSolver)
|
||||
internal.extract_constraint = Mock(side_effect=lambda cid: "<%s>" % cid)
|
||||
internal.is_constraint_satisfied = Mock(return_value=False)
|
||||
|
||||
component = StaticLazyConstraintsComponent(violation_tolerance=1.0)
|
||||
@@ -128,8 +130,8 @@ def test_usage_with_solver(instance: Instance) -> None:
|
||||
component.classifiers["type-b"].predict_proba.assert_called_once()
|
||||
|
||||
# Should ask internal solver to remove some constraints
|
||||
assert internal.extract_constraint.call_count == 1
|
||||
internal.extract_constraint.assert_has_calls([call("c3")])
|
||||
assert internal.remove_constraint.call_count == 1
|
||||
internal.remove_constraint.assert_has_calls([call("c3")])
|
||||
|
||||
# LearningSolver calls after_iteration (first time)
|
||||
should_repeat = component.iteration_cb(solver, instance, None)
|
||||
@@ -137,9 +139,10 @@ def test_usage_with_solver(instance: Instance) -> None:
|
||||
|
||||
# Should ask internal solver to verify if constraints in the pool are
|
||||
# satisfied and add the ones that are not
|
||||
internal.is_constraint_satisfied.assert_called_once_with("<c3>", tol=1.0)
|
||||
c3 = instance.features.constraints["c3"]
|
||||
internal.is_constraint_satisfied.assert_called_once_with(c3, tol=1.0)
|
||||
internal.is_constraint_satisfied.reset_mock()
|
||||
internal.add_constraint.assert_called_once_with("<c3>")
|
||||
internal.add_constraint.assert_called_once_with(c3, name="c3")
|
||||
internal.add_constraint.reset_mock()
|
||||
|
||||
# LearningSolver calls after_iteration (second time)
|
||||
|
||||
Reference in New Issue
Block a user