Replace Hashable by str

This commit is contained in:
2021-07-15 16:21:40 -05:00
parent 8d89285cb9
commit ef9c48d79a
21 changed files with 123 additions and 133 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
from typing import Dict, Tuple, List, Any, TYPE_CHECKING, Set, Optional
import numpy as np
from overrides import overrides
@@ -44,11 +44,11 @@ class StaticLazyConstraintsComponent(Component):
assert isinstance(classifier, Classifier)
self.classifier_prototype: Classifier = classifier
self.threshold_prototype: Threshold = threshold
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.classifiers: Dict[str, Classifier] = {}
self.thresholds: Dict[str, Threshold] = {}
self.pool: Constraints = Constraints()
self.violation_tolerance: float = violation_tolerance
self.enforced_cids: Set[Hashable] = set()
self.enforced_cids: Set[str] = set()
self.n_restored: int = 0
self.n_iterations: int = 0
@@ -105,8 +105,8 @@ class StaticLazyConstraintsComponent(Component):
@overrides
def fit_xy(
self,
x: Dict[Hashable, np.ndarray],
y: Dict[Hashable, np.ndarray],
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
) -> None:
for c in y.keys():
assert c in x
@@ -136,9 +136,9 @@ class StaticLazyConstraintsComponent(Component):
) -> None:
self._check_and_add(solver)
def sample_predict(self, sample: Sample) -> List[Hashable]:
def sample_predict(self, sample: Sample) -> List[str]:
x, y, cids = self._sample_xy_with_cids(sample)
enforced_cids: List[Hashable] = []
enforced_cids: List[str] = []
for category in x.keys():
if category not in self.classifiers:
continue
@@ -156,7 +156,7 @@ class StaticLazyConstraintsComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
x, y, __ = self._sample_xy_with_cids(sample)
return x, y
@@ -197,13 +197,13 @@ class StaticLazyConstraintsComponent(Component):
def _sample_xy_with_cids(
self, sample: Sample
) -> Tuple[
Dict[Hashable, List[List[float]]],
Dict[Hashable, List[List[float]]],
Dict[Hashable, List[str]],
Dict[str, List[List[float]]],
Dict[str, List[List[float]]],
Dict[str, List[str]],
]:
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
cids: Dict[Hashable, List[str]] = {}
x: Dict[str, List[List[float]]] = {}
y: Dict[str, List[List[float]]] = {}
cids: Dict[str, List[str]] = {}
instance_features = sample.get("instance_features_user")
constr_features = sample.get("lp_constr_features")
constr_names = sample.get("constr_names")