Convert Features into dataclass

This commit is contained in:
2021-04-04 22:37:16 -05:00
parent f2520f33fb
commit 59f4f75a53
12 changed files with 77 additions and 67 deletions

View File

@@ -61,13 +61,16 @@ class StaticLazyConstraintsComponent(Component):
training_data: TrainingSample,
) -> None:
assert solver.internal_solver is not None
if not features["Instance"]["Lazy constraint count"] == 0:
assert features.instance is not None
assert features.constraints is not None
if not features.instance["Lazy constraint count"] == 0:
logger.info("Instance does not have static lazy constraints. Skipping.")
logger.info("Predicting required lazy constraints...")
self.enforced_cids = set(self.sample_predict(features, training_data))
logger.info("Moving lazy constraints to the pool...")
self.pool = {}
for (cid, cdict) in features["Constraints"].items():
for (cid, cdict) in features.constraints.items():
if cdict["Lazy"] and cid not in self.enforced_cids:
self.pool[cid] = LazyConstraint(
cid=cid,
@@ -145,9 +148,11 @@ class StaticLazyConstraintsComponent(Component):
features: Features,
sample: TrainingSample,
) -> List[str]:
assert features.constraints is not None
x, y = self.sample_xy(features, sample)
category_to_cids: Dict[Hashable, List[str]] = {}
for (cid, cdict) in features["Constraints"].items():
for (cid, cdict) in features.constraints.items():
if "Category" not in cdict or cdict["Category"] is None:
continue
category = cdict["Category"]
@@ -172,9 +177,10 @@ class StaticLazyConstraintsComponent(Component):
features: Features,
sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.constraints is not None
x: Dict = {}
y: Dict = {}
for (cid, cfeatures) in features["Constraints"].items():
for (cid, cfeatures) in features.constraints.items():
if not cfeatures["Lazy"]:
continue
category = cfeatures["Category"]

View File

@@ -77,9 +77,10 @@ class ObjectiveValueComponent(Component):
features: Features,
sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.instance is not None
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
f = list(features["Instance"]["User features"])
f = list(features.instance["User features"])
if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]]
for c in ["Upper bound", "Lower bound"]:

View File

@@ -105,9 +105,11 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Solution:
assert features.variables is not None
# Initialize empty solution
solution: Solution = {}
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
solution[var_name] = {}
for idx in var_dict.keys():
solution[var_name][idx] = None
@@ -132,7 +134,7 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
offset = category_offset[category]
@@ -149,12 +151,13 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.variables is not None
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
if "Solution" in sample and sample["Solution"] is not None:
solution = sample["Solution"]
for (var_name, var_dict) in features["Variables"].items():
for (var_name, var_dict) in features.variables.items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
if category is None: