mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-10 11:28:51 -06:00
Convert TrainingSample to dataclass
This commit is contained in:
@@ -82,12 +82,14 @@ class ObjectiveValueComponent(Component):
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
f = list(features.instance.user_features)
|
||||
if "LP value" in sample and sample["LP value"] is not None:
|
||||
f += [sample["LP value"]]
|
||||
for c in ["Upper bound", "Lower bound"]:
|
||||
x[c] = [f]
|
||||
if c in sample and sample[c] is not None: # type: ignore
|
||||
y[c] = [[sample[c]]] # type: ignore
|
||||
if sample.lp_value is not None:
|
||||
f += [sample.lp_value]
|
||||
x["Upper bound"] = [f]
|
||||
x["Lower bound"] = [f]
|
||||
if sample.lower_bound is not None:
|
||||
y["Lower bound"] = [[sample.lower_bound]]
|
||||
if sample.upper_bound is not None:
|
||||
y["Upper bound"] = [[sample.upper_bound]]
|
||||
return x, y
|
||||
|
||||
def sample_evaluate(
|
||||
@@ -106,7 +108,8 @@ class ObjectiveValueComponent(Component):
|
||||
|
||||
result: Dict[Hashable, Dict[str, float]] = {}
|
||||
pred = self.sample_predict(features, sample)
|
||||
for c in ["Upper bound", "Lower bound"]:
|
||||
if c in sample and sample[c] is not None: # type: ignore
|
||||
result[c] = compare(pred[c], sample[c]) # type: ignore
|
||||
if sample.upper_bound is not None:
|
||||
result["Upper bound"] = compare(pred["Upper bound"], sample.upper_bound)
|
||||
if sample.lower_bound is not None:
|
||||
result["Lower bound"] = compare(pred["Lower bound"], sample.lower_bound)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user