Add pre argument to sample_xy

This commit is contained in:
2021-04-13 19:19:49 -05:00
parent a01c179341
commit bec7dae6d9
8 changed files with 64 additions and 33 deletions

View File

@@ -159,6 +159,7 @@ class Component(EnforceOverrides):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
"""
Returns a pair of x and y dictionaries containing, respectively, the matrices
@@ -175,7 +176,7 @@ class Component(EnforceOverrides):
) -> None:
return
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
pass
@staticmethod
@@ -183,28 +184,29 @@ class Component(EnforceOverrides):
components: Dict[str, "Component"],
instances: List[Instance],
) -> None:
x_combined: Dict = {}
y_combined: Dict = {}
for (cname, comp) in components.items():
x_combined[cname] = {}
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {}
for (cname, comp) in components.items():
pre_instance[cname] = []
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample)
pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
instance.free()
return pre_instance
# sample_xy
for instance in instances:
def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
x_sample, y_sample = comp.sample_xy(instance, sample)
x = x_instance[cname]
y = y_instance[cname]
x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
for cat in x_sample.keys():
if cat not in x:
x[cat] = []
@@ -212,12 +214,29 @@ class Component(EnforceOverrides):
x[cat] += x_sample[cat]
y[cat] += y_sample[cat]
instance.free()
return x_instance, y_instance
# fit_xy
pre = [_pre_sample_xy(instance) for instance in instances]
pre_combined: Dict = {}
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
for cat in x.keys():
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat])
comp.fit_xy(x, y)
pre_combined[cname] = []
for p in pre:
pre_combined[cname].extend(p[cname])
xy = [_sample_xy(instances, pre_combined) for instances in instances]
for (cname, comp) in components.items():
x_comp: Dict = {}
y_comp: Dict = {}
for (x, y) in xy:
for cat in x[cname].keys():
if cat not in x_comp:
x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])
comp.fit_xy(x_comp, y_comp)