Rename xy_sample to xy

This commit is contained in:
2021-04-02 06:26:48 -05:00
parent bc8fe4dc98
commit ef556f94f0
8 changed files with 37 additions and 64 deletions

View File

@@ -132,17 +132,17 @@ class Component:
return
@staticmethod
def xy_sample(
def xy(
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
) -> Tuple[Dict, Dict]:
"""
Given a set of features and a training sample, returns a pair of x and y
dictionaries containing, respectively, the matrices of ML features and the
labels for the sample. If the training sample does not include label
information, returns None.
information, returns (x, {}).
"""
return None
pass
def xy_instances(
self,
@@ -153,7 +153,7 @@ class Component:
for instance in InstanceIterator(instances):
assert isinstance(instance, Instance)
for sample in instance.training_data:
xy = self.xy_sample(instance.features, sample)
xy = self.xy(instance.features, sample)
if xy is None:
continue
x_sample, y_sample = xy

View File

@@ -207,12 +207,10 @@ class StaticLazyConstraintsComponent(Component):
return result
@staticmethod
def xy_sample(
def xy(
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
if "LazyStatic: Enforced" not in sample:
return None
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
for (cid, cfeatures) in features["Constraints"].items():
@@ -225,8 +223,9 @@ class StaticLazyConstraintsComponent(Component):
x[category] = []
y[category] = []
x[category] += [cfeatures["User features"]]
if cid in sample["LazyStatic: Enforced"]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
if "LazyStatic: Enforced" in sample:
if cid in sample["LazyStatic: Enforced"]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
return x, y

View File

@@ -166,12 +166,10 @@ class ObjectiveValueComponent(Component):
return ev
@staticmethod
def xy_sample(
def xy(
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
if "Lower bound" not in sample:
return None
) -> Tuple[Dict, Dict]:
f = features["Instance"]["User features"]
if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]]
@@ -179,8 +177,11 @@ class ObjectiveValueComponent(Component):
"Lower bound": [f],
"Upper bound": [f],
}
y = {
"Lower bound": [[sample["Lower bound"]]],
"Upper bound": [[sample["Upper bound"]]],
}
return x, y
if "Lower bound" in sample:
y = {
"Lower bound": [[sample["Lower bound"]]],
"Upper bound": [[sample["Upper bound"]]],
}
return x, y
else:
return x, {}

View File

@@ -134,7 +134,7 @@ class PrimalSolutionComponent(Component):
solution[var_name][idx] = None
# Compute y_pred
x = self.x_sample(features, sample)
x, _ = self.xy(features, sample)
y_pred = {}
for category in x.keys():
assert category in self.classifiers, (
@@ -173,7 +173,7 @@ class PrimalSolutionComponent(Component):
):
instance = instances[instance_idx]
solution_actual = instance.training_data[0]["Solution"]
solution_pred = self.predict(instance)
solution_pred = self.predict(instance, instance.training_data[0])
vars_all, vars_one, vars_zero = set(), set(), set()
pred_one_positive, pred_zero_positive = set(), set()
@@ -213,33 +213,10 @@ class PrimalSolutionComponent(Component):
return ev
@staticmethod
def xy_sample(
def xy(
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
if "Solution" not in sample:
return None
assert sample["Solution"] is not None
return cast(
Tuple[Dict, Dict],
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
def x_sample(
features: Features,
sample: TrainingSample,
) -> Dict:
return cast(
Dict,
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
def _extract(
features: Features,
sample: TrainingSample,
) -> Union[Dict, Tuple[Dict, Dict]]:
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
@@ -271,7 +248,4 @@ class PrimalSolutionComponent(Component):
"category to None."
)
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if solution is not None:
return x, y
else:
return x
return x, y