mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Allow user to attach arbitrary data to violations
This commit is contained in:
@@ -10,6 +10,7 @@ import pytest
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.dynamic_common import DynamicConstraintsComponent
|
||||
from miplearn.components.dynamic_lazy import DynamicLazyConstraintsComponent
|
||||
from miplearn.features.sample import MemorySample
|
||||
from miplearn.instance.base import Instance
|
||||
@@ -24,13 +25,23 @@ def training_instances() -> List[Instance]:
|
||||
samples_0 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c1", "c2"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c1": 0,
|
||||
b"c2": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c2", "c3"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c2": 0,
|
||||
b"c3": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([5.0]),
|
||||
},
|
||||
),
|
||||
@@ -55,7 +66,12 @@ def training_instances() -> List[Instance]:
|
||||
samples_1 = [
|
||||
MemorySample(
|
||||
{
|
||||
"mip_constr_lazy_enforced": np.array(["c3", "c4"], dtype="S"),
|
||||
"mip_constr_lazy": DynamicConstraintsComponent.encode(
|
||||
{
|
||||
b"c3": 0,
|
||||
b"c4": 0,
|
||||
}
|
||||
),
|
||||
"static_instance_features": np.array([8.0]),
|
||||
},
|
||||
)
|
||||
@@ -83,8 +99,8 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.pre_fit(
|
||||
[
|
||||
np.array(["c1", "c3", "c4"], dtype="S"),
|
||||
np.array(["c1", "c2", "c4"], dtype="S"),
|
||||
{b"c1": 0, b"c3": 0, b"c4": 0},
|
||||
{b"c1": 0, b"c2": 0, b"c4": 0},
|
||||
]
|
||||
)
|
||||
x_expected = {
|
||||
@@ -105,7 +121,10 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||
|
||||
def test_sample_predict_evaluate(training_instances: List[Instance]) -> None:
|
||||
comp = DynamicLazyConstraintsComponent()
|
||||
comp.known_cids.extend([b"c1", b"c2", b"c3", b"c4"])
|
||||
comp.known_violations[b"c1"] = 0
|
||||
comp.known_violations[b"c2"] = 0
|
||||
comp.known_violations[b"c3"] = 0
|
||||
comp.known_violations[b"c4"] = 0
|
||||
comp.thresholds[b"type-a"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.thresholds[b"type-b"] = MinProbabilityThreshold([0.5, 0.5])
|
||||
comp.classifiers[b"type-a"] = Mock(spec=Classifier)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, FrozenSet, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import gurobipy as gp
|
||||
import networkx as nx
|
||||
@@ -12,12 +12,11 @@ from gurobipy import GRB
|
||||
from networkx import Graph
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
from miplearn.components.dynamic_user_cuts import UserCutsComponent
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,13 +40,14 @@ class GurobiStableSetProblem(Instance):
|
||||
return True
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert isinstance(model, gp.Model)
|
||||
vals = model.cbGetNodeRel(model.getVars())
|
||||
violations = []
|
||||
violations = {}
|
||||
for clique in nx.find_cliques(self.graph):
|
||||
if sum(vals[i] for i in clique) > 1:
|
||||
violations.append(",".join([str(i) for i in clique]).encode())
|
||||
vname = (",".join([str(i) for i in clique])).encode()
|
||||
violations[vname] = list(clique)
|
||||
return violations
|
||||
|
||||
@overrides
|
||||
@@ -55,9 +55,8 @@ class GurobiStableSetProblem(Instance):
|
||||
self,
|
||||
solver: GurobiSolver,
|
||||
model: Any,
|
||||
cid: ConstraintName,
|
||||
clique: List[int],
|
||||
) -> Any:
|
||||
clique = [int(i) for i in cid.decode().split(",")]
|
||||
x = model.getVars()
|
||||
constr = gp.quicksum([x[i] for i in clique]) <= 1
|
||||
if solver.cb_where:
|
||||
@@ -86,9 +85,11 @@ def test_usage(
|
||||
) -> None:
|
||||
stats_before = solver.solve(stab_instance)
|
||||
sample = stab_instance.get_samples()[0]
|
||||
user_cuts_enforced = sample.get_array("mip_user_cuts_enforced")
|
||||
assert user_cuts_enforced is not None
|
||||
assert len(user_cuts_enforced) > 0
|
||||
user_cuts_encoded = sample.get_scalar("mip_user_cuts")
|
||||
assert user_cuts_encoded is not None
|
||||
user_cuts = json.loads(user_cuts_encoded)
|
||||
assert user_cuts is not None
|
||||
assert len(user_cuts) > 0
|
||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||
assert stats_before["UserCuts: Added in callback"] > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user