LearningSolver: Load each instance exactly twice during fit

This commit is contained in:
2021-04-13 18:11:37 -05:00
parent ef7a50e871
commit a01c179341
7 changed files with 116 additions and 208 deletions

View File

@@ -99,16 +99,6 @@ class Component(EnforceOverrides):
"""
return
def fit(
self,
training_instances: List[Instance],
) -> None:
x, y = self.xy_instances(training_instances)
for cat in x.keys():
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat])
self.fit_xy(x, y)
def fit_xy(
self,
x: Dict[Hashable, np.ndarray],
@@ -185,21 +175,49 @@ class Component(EnforceOverrides):
) -> None:
return
def xy_instances(
self,
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
pass
@staticmethod
def fit_multiple(
components: Dict[str, "Component"],
instances: List[Instance],
) -> Tuple[Dict, Dict]:
) -> None:
x_combined: Dict = {}
y_combined: Dict = {}
for (cname, comp) in components.items():
x_combined[cname] = {}
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
instance.load()
for sample in instance.samples:
x_sample, y_sample = self.sample_xy(instance, sample)
for cat in x_sample.keys():
if cat not in x_combined:
x_combined[cat] = []
y_combined[cat] = []
x_combined[cat] += x_sample[cat]
y_combined[cat] += y_sample[cat]
for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample)
instance.free()
return x_combined, y_combined
# sample_xy
for instance in instances:
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
x_sample, y_sample = comp.sample_xy(instance, sample)
for cat in x_sample.keys():
if cat not in x:
x[cat] = []
y[cat] = []
x[cat] += x_sample[cat]
y[cat] += y_sample[cat]
instance.free()
# fit_xy
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
for cat in x.keys():
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat])
comp.fit_xy(x, y)

View File

@@ -117,22 +117,14 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def fit(self, training_instances: List[Instance]) -> None:
collected_cids = set()
for instance in training_instances:
instance.load()
for sample in instance.samples:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
continue
collected_cids |= sample.after_mip.extra[self.attr]
instance.free()
self.known_cids.clear()
self.known_cids.extend(sorted(collected_cids))
super().fit(training_instances)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
@overrides
def fit_xy(

View File

@@ -119,8 +119,8 @@ class DynamicLazyConstraintsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def fit(self, training_instances: List[Instance]) -> None:
self.dynamic.fit(training_instances)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(

View File

@@ -112,8 +112,8 @@ class UserCutsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def fit(self, training_instances: List["Instance"]) -> None:
self.dynamic.fit(training_instances)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(