Convert Features into dataclass

This commit is contained in:
2021-04-04 22:37:16 -05:00
parent f2520f33fb
commit 59f4f75a53
12 changed files with 77 additions and 67 deletions

View File

@@ -105,9 +105,11 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Solution:
assert features.variables is not None
# Initialize empty solution
solution: Solution = {}
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
solution[var_name] = {}
for idx in var_dict.keys():
solution[var_name][idx] = None
@@ -132,7 +134,7 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
offset = category_offset[category]
@@ -149,12 +151,13 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.variables is not None
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
if "Solution" in sample and sample["Solution"] is not None:
solution = sample["Solution"]
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
if category is None: