Replace Hashable by str

This commit is contained in:
2021-07-15 16:21:40 -05:00
parent 8d89285cb9
commit ef9c48d79a
21 changed files with 123 additions and 133 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, FrozenSet, Hashable, List
from typing import Any, FrozenSet, List
import gurobipy as gp
import networkx as nx
@@ -40,13 +40,13 @@ class GurobiStableSetProblem(Instance):
return True
@overrides
def find_violated_user_cuts(self, model: Any) -> List[FrozenSet]:
def find_violated_user_cuts(self, model: Any) -> List[str]:
assert isinstance(model, gp.Model)
vals = model.cbGetNodeRel(model.getVars())
violations = []
for clique in nx.find_cliques(self.graph):
if sum(vals[i] for i in clique) > 1:
violations += [frozenset(clique)]
violations.append(",".join([str(i) for i in clique]))
return violations
@overrides
@@ -54,11 +54,11 @@ class GurobiStableSetProblem(Instance):
self,
solver: InternalSolver,
model: Any,
cid: Hashable,
cid: str,
) -> Any:
assert isinstance(cid, FrozenSet)
clique = [int(i) for i in cid.split(",")]
x = model.getVars()
model.addConstr(gp.quicksum([x[i] for i in cid]) <= 1)
model.addConstr(gp.quicksum([x[i] for i in clique]) <= 1)
@pytest.fixture

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Hashable, Dict
from typing import Dict
from unittest.mock import Mock
import numpy as np
@@ -44,11 +44,11 @@ def test_sample_xy(sample: Sample) -> None:
def test_fit_xy() -> None:
x: Dict[Hashable, np.ndarray] = {
x: Dict[str, np.ndarray] = {
"Lower bound": np.array([[0.0, 0.0], [1.0, 2.0]]),
"Upper bound": np.array([[0.0, 0.0], [1.0, 2.0]]),
}
y: Dict[Hashable, np.ndarray] = {
y: Dict[str, np.ndarray] = {
"Lower bound": np.array([[100.0]]),
"Upper bound": np.array([[200.0]]),
}

View File

@@ -121,8 +121,8 @@ def test_evaluate(sample: Sample) -> None:
assert_equals(
ev,
{
0: classifier_evaluation_dict(tp=0, fp=1, tn=1, fn=2),
1: classifier_evaluation_dict(tp=1, fp=1, tn=1, fn=1),
"0": classifier_evaluation_dict(tp=0, fp=1, tn=1, fn=2),
"1": classifier_evaluation_dict(tp=1, fp=1, tn=1, fn=1),
},
)

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Dict, cast, Hashable
from typing import Dict, cast
from unittest.mock import Mock, call
import numpy as np
@@ -175,14 +175,14 @@ def test_sample_predict(sample: Sample) -> None:
def test_fit_xy() -> None:
x = cast(
Dict[Hashable, np.ndarray],
Dict[str, np.ndarray],
{
"type-a": np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]),
"type-b": np.array([[1.0, 4.0, 0.0]]),
},
)
y = cast(
Dict[Hashable, np.ndarray],
Dict[str, np.ndarray],
{
"type-a": np.array([[False, True], [False, True], [True, False]]),
"type-b": np.array([[False, True]]),