Rename xy_sample to xy

This commit is contained in:
2021-04-02 06:26:48 -05:00
parent bc8fe4dc98
commit ef556f94f0
8 changed files with 37 additions and 64 deletions

View File

@@ -134,7 +134,7 @@ class PrimalSolutionComponent(Component):
solution[var_name][idx] = None
# Compute y_pred
x = self.x_sample(features, sample)
x, _ = self.xy(features, sample)
y_pred = {}
for category in x.keys():
assert category in self.classifiers, (
@@ -173,7 +173,7 @@ class PrimalSolutionComponent(Component):
):
instance = instances[instance_idx]
solution_actual = instance.training_data[0]["Solution"]
solution_pred = self.predict(instance)
solution_pred = self.predict(instance, instance.training_data[0])
vars_all, vars_one, vars_zero = set(), set(), set()
pred_one_positive, pred_zero_positive = set(), set()
@@ -213,33 +213,10 @@ class PrimalSolutionComponent(Component):
return ev
@staticmethod
def xy_sample(
def xy(
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
if "Solution" not in sample:
return None
assert sample["Solution"] is not None
return cast(
Tuple[Dict, Dict],
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
def x_sample(
features: Features,
sample: TrainingSample,
) -> Dict:
return cast(
Dict,
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
def _extract(
features: Features,
sample: TrainingSample,
) -> Union[Dict, Tuple[Dict, Dict]]:
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
solution: Optional[Solution] = None
@@ -271,7 +248,4 @@ class PrimalSolutionComponent(Component):
"category to None."
)
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if solution is not None:
return x, y
else:
return x
return x, y