mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -155,8 +155,8 @@ class PrimalSolutionComponent(Component):
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
solution: Optional[Solution] = None
|
||||
if "Solution" in sample and sample["Solution"] is not None:
|
||||
solution = sample["Solution"]
|
||||
if sample.solution is not None:
|
||||
solution = sample.solution
|
||||
for (var_name, var_dict) in features.variables.items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features.category
|
||||
@@ -168,8 +168,8 @@ class PrimalSolutionComponent(Component):
|
||||
f: List[float] = []
|
||||
assert var_features.user_features is not None
|
||||
f += var_features.user_features
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var_name][idx]
|
||||
if sample.lp_solution is not None:
|
||||
lp_value = sample.lp_solution[var_name][idx]
|
||||
if lp_value is not None:
|
||||
f += [lp_value]
|
||||
x[category] += [f]
|
||||
@@ -190,7 +190,7 @@ class PrimalSolutionComponent(Component):
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
solution_actual = sample["Solution"]
|
||||
solution_actual = sample.solution
|
||||
assert solution_actual is not None
|
||||
solution_pred = self.sample_predict(features, sample)
|
||||
vars_all, vars_one, vars_zero = set(), set(), set()
|
||||
|
||||
Reference in New Issue
Block a user