Replace Hashable by str

This commit is contained in:
2021-07-15 16:21:40 -05:00
parent 8d89285cb9
commit ef9c48d79a
21 changed files with 123 additions and 133 deletions

View File

@@ -3,15 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import (
Dict,
List,
Hashable,
Any,
TYPE_CHECKING,
Tuple,
Optional,
)
from typing import Dict, List, Any, TYPE_CHECKING, Tuple, Optional
import numpy as np
from overrides import overrides
@@ -55,8 +47,8 @@ class PrimalSolutionComponent(Component):
assert isinstance(threshold, Threshold)
assert mode in ["exact", "heuristic"]
self.mode = mode
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.classifiers: Dict[str, Classifier] = {}
self.thresholds: Dict[str, Threshold] = {}
self.threshold_prototype = threshold
self.classifier_prototype = classifier
@@ -128,7 +120,7 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution
solution: Solution = {v: None for v in var_names}
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
for (i, var_name) in enumerate(var_names):
category = var_categories[i]
if category not in category_offset:
@@ -194,7 +186,7 @@ class PrimalSolutionComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
) -> Dict[Hashable, Dict[str, float]]:
) -> Dict[str, Dict[str, float]]:
mip_var_values = sample.get("mip_var_values")
var_names = sample.get("var_names")
assert mip_var_values is not None
@@ -221,13 +213,13 @@ class PrimalSolutionComponent(Component):
pred_one_negative = vars_all - pred_one_positive
pred_zero_negative = vars_all - pred_zero_positive
return {
0: classifier_evaluation_dict(
"0": classifier_evaluation_dict(
tp=len(pred_zero_positive & vars_zero),
tn=len(pred_zero_negative & vars_one),
fp=len(pred_zero_positive & vars_one),
fn=len(pred_zero_negative & vars_zero),
),
1: classifier_evaluation_dict(
"1": classifier_evaluation_dict(
tp=len(pred_one_positive & vars_one),
tn=len(pred_one_negative & vars_zero),
fp=len(pred_one_positive & vars_zero),
@@ -238,8 +230,8 @@ class PrimalSolutionComponent(Component):
@overrides
def fit_xy(
self,
x: Dict[Hashable, np.ndarray],
y: Dict[Hashable, np.ndarray],
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
) -> None:
for category in x.keys():
clf = self.classifier_prototype.clone()