mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Replace Hashable by str
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
|
||||
from typing import Dict, Tuple, List, Any, TYPE_CHECKING, Set, Optional
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -44,11 +44,11 @@ class StaticLazyConstraintsComponent(Component):
|
||||
assert isinstance(classifier, Classifier)
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.threshold_prototype: Threshold = threshold
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.classifiers: Dict[str, Classifier] = {}
|
||||
self.thresholds: Dict[str, Threshold] = {}
|
||||
self.pool: Constraints = Constraints()
|
||||
self.violation_tolerance: float = violation_tolerance
|
||||
self.enforced_cids: Set[Hashable] = set()
|
||||
self.enforced_cids: Set[str] = set()
|
||||
self.n_restored: int = 0
|
||||
self.n_iterations: int = 0
|
||||
|
||||
@@ -105,8 +105,8 @@ class StaticLazyConstraintsComponent(Component):
|
||||
@overrides
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[Hashable, np.ndarray],
|
||||
y: Dict[Hashable, np.ndarray],
|
||||
x: Dict[str, np.ndarray],
|
||||
y: Dict[str, np.ndarray],
|
||||
) -> None:
|
||||
for c in y.keys():
|
||||
assert c in x
|
||||
@@ -136,9 +136,9 @@ class StaticLazyConstraintsComponent(Component):
|
||||
) -> None:
|
||||
self._check_and_add(solver)
|
||||
|
||||
def sample_predict(self, sample: Sample) -> List[Hashable]:
|
||||
def sample_predict(self, sample: Sample) -> List[str]:
|
||||
x, y, cids = self._sample_xy_with_cids(sample)
|
||||
enforced_cids: List[Hashable] = []
|
||||
enforced_cids: List[str] = []
|
||||
for category in x.keys():
|
||||
if category not in self.classifiers:
|
||||
continue
|
||||
@@ -156,7 +156,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
self,
|
||||
_: Optional[Instance],
|
||||
sample: Sample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
|
||||
x, y, __ = self._sample_xy_with_cids(sample)
|
||||
return x, y
|
||||
|
||||
@@ -197,13 +197,13 @@ class StaticLazyConstraintsComponent(Component):
|
||||
def _sample_xy_with_cids(
|
||||
self, sample: Sample
|
||||
) -> Tuple[
|
||||
Dict[Hashable, List[List[float]]],
|
||||
Dict[Hashable, List[List[float]]],
|
||||
Dict[Hashable, List[str]],
|
||||
Dict[str, List[List[float]]],
|
||||
Dict[str, List[List[float]]],
|
||||
Dict[str, List[str]],
|
||||
]:
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
cids: Dict[Hashable, List[str]] = {}
|
||||
x: Dict[str, List[List[float]]] = {}
|
||||
y: Dict[str, List[List[float]]] = {}
|
||||
cids: Dict[str, List[str]] = {}
|
||||
instance_features = sample.get("instance_features_user")
|
||||
constr_features = sample.get("lp_constr_features")
|
||||
constr_names = sample.get("constr_names")
|
||||
|
||||
Reference in New Issue
Block a user