Replace Hashable by str

This commit is contained in:
2021-07-15 16:21:40 -05:00
parent 8d89285cb9
commit ef9c48d79a
21 changed files with 123 additions and 133 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable, Optional
from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Optional
import numpy as np
from overrides import overrides
@@ -53,8 +53,8 @@ class ObjectiveValueComponent(Component):
@overrides
def fit_xy(
self,
x: Dict[Hashable, np.ndarray],
y: Dict[Hashable, np.ndarray],
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
) -> None:
for c in ["Upper bound", "Lower bound"]:
if c in y:
@@ -76,20 +76,20 @@ class ObjectiveValueComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
lp_instance_features = sample.get("lp_instance_features")
if lp_instance_features is None:
lp_instance_features = sample.get("instance_features_user")
assert lp_instance_features is not None
# Features
x: Dict[Hashable, List[List[float]]] = {
x: Dict[str, List[List[float]]] = {
"Upper bound": [lp_instance_features],
"Lower bound": [lp_instance_features],
}
# Labels
y: Dict[Hashable, List[List[float]]] = {}
y: Dict[str, List[List[float]]] = {}
mip_lower_bound = sample.get("mip_lower_bound")
mip_upper_bound = sample.get("mip_upper_bound")
if mip_lower_bound is not None:
@@ -104,7 +104,7 @@ class ObjectiveValueComponent(Component):
self,
instance: Instance,
sample: Sample,
) -> Dict[Hashable, Dict[str, float]]:
) -> Dict[str, Dict[str, float]]:
def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
err = np.round(abs(y_pred - y_actual), 8)
return {
@@ -114,7 +114,7 @@ class ObjectiveValueComponent(Component):
"Relative error": err / y_actual,
}
result: Dict[Hashable, Dict[str, float]] = {}
result: Dict[str, Dict[str, float]] = {}
pred = self.sample_predict(sample)
actual_ub = sample.get("mip_upper_bound")
actual_lb = sample.get("mip_lower_bound")