Convert TrainingSample to dataclass

This commit is contained in:
2021-04-05 20:36:04 -05:00
parent aeed338837
commit b11779817a
15 changed files with 122 additions and 129 deletions

View File

@@ -93,7 +93,7 @@ class StaticLazyConstraintsComponent(Component):
features: Features,
training_data: TrainingSample,
) -> None:
training_data["LazyStatic: Enforced"] = self.enforced_cids
training_data.lazy_enforced = self.enforced_cids
stats["LazyStatic: Restored"] = self.n_restored
stats["LazyStatic: Iterations"] = self.n_iterations
@@ -188,8 +188,8 @@ class StaticLazyConstraintsComponent(Component):
x[category] = []
y[category] = []
x[category] += [cfeatures.user_features]
if "LazyStatic: Enforced" in sample:
if cid in sample["LazyStatic: Enforced"]:
if sample.lazy_enforced is not None:
if cid in sample.lazy_enforced:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]

View File

@@ -82,12 +82,14 @@ class ObjectiveValueComponent(Component):
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance.user_features)
if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]]
for c in ["Upper bound", "Lower bound"]:
x[c] = [f]
if c in sample and sample[c] is not None: # type: ignore
y[c] = [[sample[c]]] # type: ignore
if sample.lp_value is not None:
f += [sample.lp_value]
x["Upper bound"] = [f]
x["Lower bound"] = [f]
if sample.lower_bound is not None:
y["Lower bound"] = [[sample.lower_bound]]
if sample.upper_bound is not None:
y["Upper bound"] = [[sample.upper_bound]]
return x, y
def sample_evaluate(
@@ -106,7 +108,8 @@ class ObjectiveValueComponent(Component):
result: Dict[Hashable, Dict[str, float]] = {}
pred = self.sample_predict(features, sample)
for c in ["Upper bound", "Lower bound"]:
if c in sample and sample[c] is not None: # type: ignore
result[c] = compare(pred[c], sample[c]) # type: ignore
if sample.upper_bound is not None:
result["Upper bound"] = compare(pred["Upper bound"], sample.upper_bound)
if sample.lower_bound is not None:
result["Lower bound"] = compare(pred["Lower bound"], sample.lower_bound)
return result

View File

@@ -155,8 +155,8 @@ class PrimalSolutionComponent(Component):
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
if "Solution" in sample and sample["Solution"] is not None:
solution = sample["Solution"]
if sample.solution is not None:
solution = sample.solution
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features.category
@@ -168,8 +168,8 @@ class PrimalSolutionComponent(Component):
f: List[float] = []
assert var_features.user_features is not None
f += var_features.user_features
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var_name][idx]
if sample.lp_solution is not None:
lp_value = sample.lp_solution[var_name][idx]
if lp_value is not None:
f += [lp_value]
x[category] += [f]
@@ -190,7 +190,7 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
solution_actual = sample["Solution"]
solution_actual = sample.solution
assert solution_actual is not None
solution_pred = self.sample_predict(features, sample)
vars_all, vars_one, vars_zero = set(), set(), set()

View File

@@ -95,8 +95,8 @@ class ConvertTightIneqsIntoEqsStep(Component):
features,
training_data,
):
if "slacks" not in training_data.keys():
training_data["slacks"] = solver.internal_solver.get_inequality_slacks()
if training_data.slacks is None:
training_data.slacks = solver.internal_solver.get_inequality_slacks()
stats["ConvertTight: Restored"] = self.n_restored
stats["ConvertTight: Inf iterations"] = self.n_infeasible_iterations
stats["ConvertTight: Subopt iterations"] = self.n_suboptimal_iterations
@@ -120,7 +120,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
disable=len(instances) < 5,
):
for training_data in instance.training_data:
cids = training_data["slacks"].keys()
cids = training_data.slacks.keys()
for cid in cids:
category = instance.get_constraint_category(cid)
if category is None:
@@ -142,7 +142,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
desc="Extract (rlx:conv_ineqs:y)",
disable=len(instances) < 5,
):
for (cid, slack) in instance.training_data[0]["slacks"].items():
for (cid, slack) in instance.training_data[0].slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
continue

View File

@@ -96,8 +96,8 @@ class DropRedundantInequalitiesStep(Component):
features,
training_data,
):
if "slacks" not in training_data.keys():
training_data["slacks"] = solver.internal_solver.get_inequality_slacks()
if training_data.slacks is None:
training_data.slacks = solver.internal_solver.get_inequality_slacks()
stats["DropRedundant: Iterations"] = self.n_iterations
stats["DropRedundant: Restored"] = self.n_restored
@@ -131,7 +131,7 @@ class DropRedundantInequalitiesStep(Component):
x = {}
y = {}
for training_data in instance.training_data:
for (cid, slack) in training_data["slacks"].items():
for (cid, slack) in training_data.slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
continue