From 088d679f61cbe9b6485df14fbb16ee4d7ef7bee5 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sat, 10 Apr 2021 15:46:53 -0500 Subject: [PATCH] Redesign InternalSolver constraint methods --- miplearn/components/dynamic_lazy.py | 1 + miplearn/components/dynamic_user_cuts.py | 7 +- miplearn/components/static_lazy.py | 24 ++-- miplearn/instance/base.py | 8 +- miplearn/instance/picklegz.py | 9 +- miplearn/problems/tsp.py | 7 +- miplearn/solvers/gurobi.py | 128 +++++++-------------- miplearn/solvers/internal.py | 24 ++-- miplearn/solvers/pyomo/base.py | 90 +++++++-------- miplearn/solvers/tests/__init__.py | 50 +++----- tests/components/test_dynamic_user_cuts.py | 10 +- tests/components/test_static_lazy.py | 13 ++- 12 files changed, 155 insertions(+), 216 deletions(-) diff --git a/miplearn/components/dynamic_lazy.py b/miplearn/components/dynamic_lazy.py index 7b7af20..b1e8ca0 100644 --- a/miplearn/components/dynamic_lazy.py +++ b/miplearn/components/dynamic_lazy.py @@ -76,6 +76,7 @@ class DynamicLazyConstraintsComponent(Component): instance: Instance, model: Any, ) -> bool: + assert solver.internal_solver is not None logger.debug("Finding violated lazy constraints...") cids = instance.find_violated_lazy_constraints(solver.internal_solver, model) if len(cids) == 0: diff --git a/miplearn/components/dynamic_user_cuts.py b/miplearn/components/dynamic_user_cuts.py index 802ea67..a1e409b 100644 --- a/miplearn/components/dynamic_user_cuts.py +++ b/miplearn/components/dynamic_user_cuts.py @@ -53,8 +53,7 @@ class UserCutsComponent(Component): cids = self.dynamic.sample_predict(instance, training_data) logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids)) for cid in cids: - cobj = instance.build_user_cut(model, cid) - solver.internal_solver.add_constraint(cobj) + instance.enforce_user_cut(solver.internal_solver, model, cid) stats["UserCuts: Added ahead-of-time"] = len(cids) @overrides @@ -73,9 +72,7 @@ class UserCutsComponent(Component): if cid in self.enforced: continue assert isinstance(cid, Hashable) - cobj = instance.build_user_cut(model, cid) - assert cobj is not None - solver.internal_solver.add_cut(cobj) + instance.enforce_user_cut(solver.internal_solver, model, cid) self.enforced.add(cid) self.n_added_in_callback += 1 if len(cids) > 0: diff --git a/miplearn/components/static_lazy.py b/miplearn/components/static_lazy.py index 1a196d8..69a713d 100644 --- a/miplearn/components/static_lazy.py +++ b/miplearn/components/static_lazy.py @@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold from miplearn.components.component import Component -from miplearn.features import TrainingSample, Features +from miplearn.features import TrainingSample, Features, Constraint from miplearn.types import LearningSolveStats logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ class StaticLazyConstraintsComponent(Component): self.threshold_prototype: Threshold = threshold self.classifiers: Dict[Hashable, Classifier] = {} self.thresholds: Dict[Hashable, Threshold] = {} - self.pool: Dict[str, LazyConstraint] = {} + self.pool: Dict[str, Constraint] = {} self.violation_tolerance: float = violation_tolerance self.enforced_cids: Set[Hashable] = set() self.n_restored: int = 0 @@ -72,10 +72,8 @@ class StaticLazyConstraintsComponent(Component): self.pool = {} for (cid, cdict) in features.constraints.items(): if cdict.lazy and cid not in self.enforced_cids: - self.pool[cid] = LazyConstraint( - cid=cid, - obj=solver.internal_solver.extract_constraint(cid), - ) + self.pool[cid] = cdict + solver.internal_solver.remove_constraint(cid) logger.info( f"{len(self.enforced_cids)} lazy constraints kept; " f"{len(self.pool)} moved to the pool" @@ -124,18 +122,18 @@ class StaticLazyConstraintsComponent(Component): def _check_and_add(self, solver: "LearningSolver") -> bool: assert solver.internal_solver is not None logger.info("Finding violated lazy constraints...") - enforced: List[LazyConstraint] = [] + enforced: Dict[str, Constraint] = {} for (cid, c) in self.pool.items(): if not solver.internal_solver.is_constraint_satisfied( - c.obj, + c, tol=self.violation_tolerance, ): - enforced.append(c) + enforced[cid] = c logger.info(f"{len(enforced)} violations found") - for c in enforced: - del self.pool[c.cid] - solver.internal_solver.add_constraint(c.obj) - self.enforced_cids.add(c.cid) + for (cid, c) in enforced.items(): + del self.pool[cid] + solver.internal_solver.add_constraint(c, name=cid) + self.enforced_cids.add(cid) self.n_restored += 1 logger.info( f"{len(enforced)} constraints restored; {len(self.pool)} in the pool" diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index ec1cfd4..2466928 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: from miplearn.solvers.learning import InternalSolver + # noinspection PyMethodMayBeStatic class Instance(ABC, EnforceOverrides): """ @@ -170,7 +171,12 @@ class Instance(ABC, EnforceOverrides): def find_violated_user_cuts(self, model: Any) -> List[Hashable]: return [] - def build_user_cut(self, model: Any, violation: Hashable) -> Any: + def enforce_user_cut( + self, + solver: "InternalSolver", + model: Any, + violation: Hashable, + ) -> Any: return None def load(self) -> None: diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index d6599b5..7d165e7 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -106,9 +106,14 @@ class PickleGzInstance(Instance): return self.instance.find_violated_user_cuts(model) @overrides - def build_user_cut(self, model: Any, violation: Hashable) -> None: + def enforce_user_cut( + self, + solver: "InternalSolver", + model: Any, + violation: Hashable, + ) -> None: assert self.instance is not None - self.instance.build_user_cut(model, violation) + self.instance.enforce_user_cut(solver, model, violation) @overrides def load(self) -> None: diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index b8cea97..545d595 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -11,7 +11,7 @@ from scipy.spatial.distance import pdist, squareform from scipy.stats import uniform, randint from scipy.stats.distributions import rv_frozen -from miplearn import InternalSolver +from miplearn import InternalSolver, BasePyomoSolver from miplearn.instance.base import Instance from miplearn.types import VariableName, Category @@ -108,14 +108,15 @@ class TravelingSalesmanInstance(Instance): model: Any, component: FrozenSet, ) -> None: + assert isinstance(solver, BasePyomoSolver) cut_edges = [ e for e in self.edges if (e[0] in component and e[1] not in component) or (e[0] not in component and e[1] in component) ] - constr = model.eq_subtour.add(sum(model.x[e] for e in cut_edges) >= 2) - solver.add_constraint(constr) + constr = model.eq_subtour.add(expr=sum(model.x[e] for e in cut_edges) >= 2) + solver.add_constraint(constr, name="") class TravelingSalesmanGenerator: diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 82396b1..0eaa149 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -32,14 +32,6 @@ from miplearn.types import ( logger = logging.getLogger(__name__) -@dataclass -class ExtractedGurobiConstraint: - lhs: Any - rhs: float - sense: str - name: str - - class GurobiSolver(InternalSolver): """ An InternalSolver backed by Gurobi's Python API (without Pyomo). @@ -72,10 +64,10 @@ class GurobiSolver(InternalSolver): self.instance: Optional[Instance] = None self.model: Optional["gurobipy.Model"] = None self.params: SolverParams = params - self.varname_to_var: Dict[str, "gurobipy.Var"] = {} - self.bin_vars: List["gurobipy.Var"] = [] self.cb_where: Optional[int] = None self.lazy_cb_frequency = lazy_cb_frequency + self._bin_vars: List["gurobipy.Var"] = [] + self._varname_to_var: Dict[str, "gurobipy.Var"] = {} if self.lazy_cb_frequency == 1: self.lazy_cb_where = [self.gp.GRB.Callback.MIPSOL] @@ -106,20 +98,20 @@ class GurobiSolver(InternalSolver): def _update_vars(self) -> None: assert self.model is not None - self.varname_to_var.clear() - self.bin_vars.clear() + self._varname_to_var.clear() + self._bin_vars.clear() for var in self.model.getVars(): - assert var.varName not in self.varname_to_var, ( + assert var.varName not in self._varname_to_var, ( f"Duplicated variable name detected: {var.varName}. " f"Unique variable names are currently required." ) - self.varname_to_var[var.varName] = var + self._varname_to_var[var.varName] = var assert var.vtype in ["B", "C"], ( "Only binary and continuous variables are currently supported. " "Variable {var.varName} has type {var.vtype}." ) if var.vtype == "B": - self.bin_vars.append(var) + self._bin_vars.append(var) def _apply_params(self, streams: List[Any]) -> None: assert self.model is not None @@ -138,13 +130,13 @@ class GurobiSolver(InternalSolver): streams += [sys.stdout] self._apply_params(streams) assert self.model is not None - for var in self.bin_vars: + for var in self._bin_vars: var.vtype = self.gp.GRB.CONTINUOUS var.lb = 0.0 var.ub = 1.0 with _RedirectOutput(streams): self.model.optimize() - for var in self.bin_vars: + for var in self._bin_vars: var.vtype = self.gp.GRB.BINARY log = streams[0].getvalue() opt_value = None @@ -262,7 +254,7 @@ class GurobiSolver(InternalSolver): self._raise_if_callback() self._clear_warm_start() for (var_name, value) in solution.items(): - var = self.varname_to_var[var_name] + var = self._varname_to_var[var_name] if value is not None: var.start = value @@ -288,52 +280,54 @@ class GurobiSolver(InternalSolver): else: return c.pi - def _get_value(self, var: Any) -> Optional[float]: + def _get_value(self, var: Any) -> float: assert self.model is not None if self.cb_where == self.gp.GRB.Callback.MIPSOL: return self.model.cbGetSolution(var) elif self.cb_where == self.gp.GRB.Callback.MIPNODE: return self.model.cbGetNodeRel(var) elif self.cb_where is None: - if self.is_infeasible(): - return None - else: - return var.x + return var.x else: raise Exception( "get_value cannot be called from cb_where=%s" % self.cb_where ) @overrides - def add_constraint(self, cobj: Any, name: str = "") -> None: + def add_constraint(self, constr: Constraint, name: str) -> None: assert self.model is not None - if isinstance(cobj, ExtractedGurobiConstraint): - if self.cb_where in [ - self.gp.GRB.Callback.MIPSOL, - self.gp.GRB.Callback.MIPNODE, - ]: - self.model.cbLazy(cobj.lhs, cobj.sense, cobj.rhs) - else: - self.model.addConstr(cobj.lhs, cobj.sense, cobj.rhs, cobj.name) - elif isinstance(cobj, self.gp.TempConstr): - if self.cb_where in [ - self.gp.GRB.Callback.MIPSOL, - self.gp.GRB.Callback.MIPNODE, - ]: - self.model.cbLazy(cobj) - else: - self.model.addConstr(cobj, name=name) + lhs = self.gp.quicksum( + self._varname_to_var[varname] * coeff + for (varname, coeff) in constr.lhs.items() + ) + if constr.sense == "=": + self.model.addConstr(lhs == constr.rhs, name=name) + elif constr.sense == "<": + self.model.addConstr(lhs <= constr.rhs, name=name) else: - raise Exception(f"unknown constraint type: {cobj.__class__.__name__}") + self.model.addConstr(lhs >= constr.rhs, name=name) @overrides - def add_cut(self, cobj: Any) -> None: + def remove_constraint(self, name: str) -> None: assert self.model is not None - assert self.cb_where == self.gp.GRB.Callback.MIPNODE - self.model.cbCut(cobj) + constr = self.model.getConstrByName(name) + self.model.remove(constr) + + @overrides + def is_constraint_satisfied(self, constr: Constraint, tol: float = 1e-6) -> bool: + lhs = 0.0 + for (varname, coeff) in constr.lhs.items(): + var = self._varname_to_var[varname] + lhs += self._get_value(var) * coeff + if constr.sense == "<": + return lhs <= constr.rhs + tol + elif constr.sense == ">": + return lhs >= constr.rhs - tol + else: + return abs(constr.rhs - lhs) < abs(tol) def _clear_warm_start(self) -> None: - for var in self.varname_to_var.values(): + for var in self._varname_to_var.values(): var.start = self.gp.GRB.UNDEFINED @overrides @@ -342,50 +336,11 @@ class GurobiSolver(InternalSolver): for (varname, value) in solution.items(): if value is None: continue - var = self.varname_to_var[varname] + var = self._varname_to_var[varname] var.vtype = self.gp.GRB.CONTINUOUS var.lb = value var.ub = value - @overrides - def extract_constraint(self, cid: str) -> ExtractedGurobiConstraint: - self._raise_if_callback() - assert self.model is not None - constr = self.model.getConstrByName(cid) - cobj = ExtractedGurobiConstraint( - lhs=self.model.getRow(constr), - sense=constr.sense, - rhs=constr.RHS, - name=constr.ConstrName, - ) - self.model.remove(constr) - return cobj - - @overrides - def is_constraint_satisfied( - self, - cobj: ExtractedGurobiConstraint, - tol: float = 1e-6, - ) -> bool: - assert isinstance(cobj, ExtractedGurobiConstraint) - lhs, sense, rhs, _ = cobj.lhs, cobj.sense, cobj.rhs, cobj.name - if self.cb_where is not None: - lhs_value = lhs.getConstant() - for i in range(lhs.size()): - var = lhs.getVar(i) - coeff = lhs.getCoeff(i) - lhs_value += self._get_value(var) * coeff - else: - lhs_value = lhs.getValue() - if sense == "<": - return lhs_value <= rhs + tol - elif sense == ">": - return lhs_value >= rhs - tol - elif sense == "=": - return abs(rhs - lhs_value) < abs(tol) - else: - raise Exception("Unknown sense: %s" % sense) - @overrides def get_inequality_slacks(self) -> Dict[str, float]: assert self.model is not None @@ -545,4 +500,5 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack): model: Any, violation: Hashable, ) -> None: - solver.add_constraint(model.getVarByName("x[0]") <= 0, name="cut") + x0 = model.getVarByName("x[0]") + model.cbLazy(x0 <= 0) diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index cd6b800..243b5db 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -6,6 +6,8 @@ import logging from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional +from overrides import EnforceOverrides + from miplearn.features import Constraint from miplearn.instance.base import Instance from miplearn.types import ( @@ -22,7 +24,7 @@ from miplearn.types import ( logger = logging.getLogger(__name__) -class InternalSolver(ABC): +class InternalSolver(ABC, EnforceOverrides): """ Abstract class representing the MIP solver used internally by LearningSolver. """ @@ -155,31 +157,21 @@ class InternalSolver(ABC): pass @abstractmethod - def add_constraint(self, cobj: Any, name: str = "") -> None: + def add_constraint(self, constr: Constraint, name: str) -> None: """ - Adds a single constraint to the model. + Adds a given constraint to the model. """ pass - def add_cut(self, cobj: Any) -> None: - """ - Adds a cutting plane to the model. This function can only be called from a user - cut callback. - """ - raise NotImplementedError() - @abstractmethod - def extract_constraint(self, cid: str) -> Any: + def remove_constraint(self, name: str) -> None: """ - Removes a given constraint from the model and returns an object `cobj` which - can be used to verify if the removed constraint is still satisfied by - the current solution, using `is_constraint_satisfied(cobj)`, and can potentially - be re-added to the model using `add_constraint(cobj)`. + Removes the constraint that has a given name from the model. """ pass @abstractmethod - def is_constraint_satisfied(self, cobj: Any, tol: float = 1e-6) -> bool: + def is_constraint_satisfied(self, constr: Constraint, tol: float = 1e-6) -> bool: """ Returns True if the current solution satisfies the given constraint. """ diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index 2ea570f..1934bd8 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -6,14 +6,15 @@ import logging import re import sys from io import StringIO -from typing import Any, List, Dict, Optional, Hashable +from typing import Any, List, Dict, Optional +import numpy as np import pyomo from overrides import overrides from pyomo import environ as pe from pyomo.core import Var from pyomo.core.base import _GeneralVarData -from pyomo.core.base.constraint import SimpleConstraint, ConstraintList +from pyomo.core.base.constraint import ConstraintList from pyomo.core.expr.numeric_expr import SumExpression, MonomialTermExpression from pyomo.opt import TerminationCondition from pyomo.opt.base.solvers import SolverFactory @@ -35,7 +36,6 @@ from miplearn.types import ( VariableName, Category, ) -import numpy as np logger = logging.getLogger(__name__) @@ -215,7 +215,7 @@ class BasePyomoSolver(InternalSolver): def _update_constrs(self) -> None: assert self.model is not None - self._cname_to_constr = {} + self._cname_to_constr.clear() for constr in self.model.component_objects(pyomo.core.Constraint): if isinstance(constr, pe.ConstraintList): for idx in constr: @@ -233,24 +233,50 @@ class BasePyomoSolver(InternalSolver): self._pyomo_solver.update_var(var) @overrides - def add_constraint(self, cobj: Any, name: str = "") -> Any: + def add_constraint( + self, + constr: Any, + name: str, + ) -> None: assert self.model is not None - if isinstance(cobj, Constraint): + if isinstance(constr, Constraint): lhs = 0.0 - for (varname, coeff) in cobj.lhs.items(): + for (varname, coeff) in constr.lhs.items(): var = self._varname_to_var[varname] lhs += var * coeff - if cobj.sense == "=": - expr = lhs == cobj.rhs - elif cobj.sense == "<": - expr = lhs <= cobj.rhs + if constr.sense == "=": + expr = lhs == constr.rhs + elif constr.sense == "<": + expr = lhs <= constr.rhs else: - expr = lhs >= cobj.rhs - cl = self.model.extra_constraints - self._pyomo_solver.add_constraint(cl.add(expr)) + expr = lhs >= constr.rhs + cl = pe.Constraint(expr=expr, name=name) + self.model.add_component(name, cl) + self._pyomo_solver.add_constraint(cl) + self._cname_to_constr[name] = cl else: - self._pyomo_solver.add_constraint(cobj) - self._update_constrs() + self._pyomo_solver.add_constraint(constr) + + @overrides + def remove_constraint(self, name: str) -> None: + assert self.model is not None + constr = self._cname_to_constr[name] + del self._cname_to_constr[name] + self.model.del_component(constr) + self._pyomo_solver.remove_constraint(constr) + + @overrides + def is_constraint_satisfied(self, constr: Constraint, tol: float = 1e-6) -> bool: + lhs = 0.0 + for (varname, coeff) in constr.lhs.items(): + var = self._varname_to_var[varname] + lhs += var.value * coeff + if constr.sense == "<": + return lhs <= constr.rhs + tol + elif constr.sense == ">": + return lhs >= constr.rhs - tol + else: + return abs(constr.rhs - lhs) < abs(tol) @staticmethod def __extract( @@ -304,27 +330,6 @@ class BasePyomoSolver(InternalSolver): result[cname] = cobj.slack() return result - @overrides - def extract_constraint(self, cid: str) -> Any: - cobj = self._cname_to_constr[cid] - constr = self._parse_pyomo_constraint(cobj) - self._pyomo_solver.remove_constraint(cobj) - return constr - - @overrides - def is_constraint_satisfied(self, cobj: Any, tol: float = 1e-6) -> bool: - assert isinstance(cobj, Constraint) - lhs_value = 0.0 - for (varname, coeff) in cobj.lhs.items(): - var = self._varname_to_var[varname] - lhs_value += var.value * coeff - if cobj.sense == "=": - return (lhs_value <= cobj.rhs + tol) and (lhs_value >= cobj.rhs - tol) - elif cobj.sense == "<": - return lhs_value <= cobj.rhs + tol - else: - return lhs_value >= cobj.rhs - tol - @overrides def is_infeasible(self) -> bool: return self._termination_condition == TerminationCondition.infeasible @@ -411,6 +416,7 @@ class BasePyomoSolver(InternalSolver): sense=sense, ) + @overrides def are_callbacks_supported(self) -> bool: return False @@ -483,13 +489,3 @@ class PyomoTestInstanceKnapsack(Instance): self.weights[item], self.prices[item], ] - - @overrides - def enforce_lazy_constraint( - self, - solver: InternalSolver, - model: Any, - violation: Hashable, - ) -> None: - model.cut = pe.Constraint(expr=model.x[0] <= 0.0, name="cut") - solver.add_constraint(model.cut) diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index c58186b..233c5c8 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -75,39 +75,29 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: solver.get_constraints(), { "eq_capacity": Constraint( - lhs={ - "x[0]": 23.0, - "x[1]": 26.0, - "x[2]": 20.0, - "x[3]": 18.0, - }, + lhs={"x[0]": 23.0, "x[1]": 26.0, "x[2]": 20.0, "x[3]": 18.0}, rhs=67.0, sense="<", ), }, ) - # Add a brand new constraint - instance.enforce_lazy_constraint(solver, model, "cut") + # Build a new constraint + cut = Constraint(lhs={"x[0]": 1.0}, sense="<", rhs=0.0) + assert not solver.is_constraint_satisfied(cut) - # New constraint should be listed + # Add new constraint and verify that it is listed + solver.add_constraint(cut, "cut") assert_equals( solver.get_constraints(), { "eq_capacity": Constraint( - lhs={ - "x[0]": 23.0, - "x[1]": 26.0, - "x[2]": 20.0, - "x[3]": 18.0, - }, + lhs={"x[0]": 23.0, "x[1]": 26.0, "x[2]": 20.0, "x[3]": 18.0}, rhs=67.0, sense="<", ), "cut": Constraint( - lhs={ - "x[0]": 1.0, - }, + lhs={"x[0]": 1.0}, rhs=0.0, sense="<", ), @@ -117,35 +107,23 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # New constraint should affect the solution stats = solver.solve() assert_equals(stats["Lower bound"], 1030.0) + assert solver.is_constraint_satisfied(cut) # Verify slacks assert_equals( solver.get_inequality_slacks(), - { - "cut": 0.0, - "eq_capacity": 3.0, - }, + {"cut": 0.0, "eq_capacity": 3.0}, ) - # # Extract the new constraint - cobj = solver.extract_constraint("cut") + # Remove the new constraint + solver.remove_constraint("cut") # New constraint should no longer affect solution stats = solver.solve() assert_equals(stats["Lower bound"], 1183.0) - # New constraint should not be satisfied by current solution - assert not solver.is_constraint_satisfied(cobj) - - # Re-add constraint - solver.add_constraint(cobj) - - # Constraint should affect solution again - stats = solver.solve() - assert_equals(stats["Lower bound"], 1030.0) - - # New constraint should now be satisfied - assert solver.is_constraint_satisfied(cobj) + # Constraint should not be satisfied by current solution + assert not solver.is_constraint_satisfied(cut) def run_warm_start_tests(solver: InternalSolver) -> None: diff --git a/tests/components/test_dynamic_user_cuts.py b/tests/components/test_dynamic_user_cuts.py index a1da5fb..199efed 100644 --- a/tests/components/test_dynamic_user_cuts.py +++ b/tests/components/test_dynamic_user_cuts.py @@ -12,6 +12,7 @@ from gurobipy import GRB from networkx import Graph from overrides import overrides +from miplearn import InternalSolver from miplearn.components.dynamic_user_cuts import UserCutsComponent from miplearn.instance.base import Instance from miplearn.solvers.gurobi import GurobiSolver @@ -49,10 +50,15 @@ class GurobiStableSetProblem(Instance): return violations @overrides - def build_user_cut(self, model: Any, cid: Hashable) -> Any: + def enforce_user_cut( + self, + solver: InternalSolver, + model: Any, + cid: Hashable, + ) -> Any: assert isinstance(cid, FrozenSet) x = model.getVars() - return gp.quicksum([x[i] for i in cid]) <= 1 + model.addConstr(gp.quicksum([x[i] for i in cid]) <= 1) @pytest.fixture diff --git a/tests/components/test_static_lazy.py b/tests/components/test_static_lazy.py index be8070a..d8cd25a 100644 --- a/tests/components/test_static_lazy.py +++ b/tests/components/test_static_lazy.py @@ -78,12 +78,14 @@ def features() -> Features: def test_usage_with_solver(instance: Instance) -> None: + assert instance.features is not None + assert instance.features.constraints is not None + solver = Mock(spec=LearningSolver) solver.use_lazy_cb = False solver.gap_tolerance = 1e-4 internal = solver.internal_solver = Mock(spec=InternalSolver) - internal.extract_constraint = Mock(side_effect=lambda cid: "<%s>" % cid) internal.is_constraint_satisfied = Mock(return_value=False) component = StaticLazyConstraintsComponent(violation_tolerance=1.0) @@ -128,8 +130,8 @@ def test_usage_with_solver(instance: Instance) -> None: component.classifiers["type-b"].predict_proba.assert_called_once() # Should ask internal solver to remove some constraints - assert internal.extract_constraint.call_count == 1 - internal.extract_constraint.assert_has_calls([call("c3")]) + assert internal.remove_constraint.call_count == 1 + internal.remove_constraint.assert_has_calls([call("c3")]) # LearningSolver calls after_iteration (first time) should_repeat = component.iteration_cb(solver, instance, None) @@ -137,9 +139,10 @@ def test_usage_with_solver(instance: Instance) -> None: # Should ask internal solver to verify if constraints in the pool are # satisfied and add the ones that are not - internal.is_constraint_satisfied.assert_called_once_with("", tol=1.0) + c3 = instance.features.constraints["c3"] + internal.is_constraint_satisfied.assert_called_once_with(c3, tol=1.0) internal.is_constraint_satisfied.reset_mock() - internal.add_constraint.assert_called_once_with("") + internal.add_constraint.assert_called_once_with(c3, name="c3") internal.add_constraint.reset_mock() # LearningSolver calls after_iteration (second time)