mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: Load each instance exactly twice during fit
This commit is contained in:
@@ -99,16 +99,6 @@ class Component(EnforceOverrides):
|
||||
"""
|
||||
return
|
||||
|
||||
def fit(
|
||||
self,
|
||||
training_instances: List[Instance],
|
||||
) -> None:
|
||||
x, y = self.xy_instances(training_instances)
|
||||
for cat in x.keys():
|
||||
x[cat] = np.array(x[cat], dtype=np.float32)
|
||||
y[cat] = np.array(y[cat])
|
||||
self.fit_xy(x, y)
|
||||
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[Hashable, np.ndarray],
|
||||
@@ -185,21 +175,49 @@ class Component(EnforceOverrides):
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def xy_instances(
|
||||
self,
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def fit_multiple(
|
||||
components: Dict[str, "Component"],
|
||||
instances: List[Instance],
|
||||
) -> Tuple[Dict, Dict]:
|
||||
) -> None:
|
||||
x_combined: Dict = {}
|
||||
y_combined: Dict = {}
|
||||
for (cname, comp) in components.items():
|
||||
x_combined[cname] = {}
|
||||
y_combined[cname] = {}
|
||||
|
||||
# pre_sample_xy
|
||||
for instance in instances:
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
x_sample, y_sample = self.sample_xy(instance, sample)
|
||||
for cat in x_sample.keys():
|
||||
if cat not in x_combined:
|
||||
x_combined[cat] = []
|
||||
y_combined[cat] = []
|
||||
x_combined[cat] += x_sample[cat]
|
||||
y_combined[cat] += y_sample[cat]
|
||||
for (cname, comp) in components.items():
|
||||
comp.pre_sample_xy(instance, sample)
|
||||
instance.free()
|
||||
return x_combined, y_combined
|
||||
|
||||
# sample_xy
|
||||
for instance in instances:
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
for (cname, comp) in components.items():
|
||||
x = x_combined[cname]
|
||||
y = y_combined[cname]
|
||||
x_sample, y_sample = comp.sample_xy(instance, sample)
|
||||
for cat in x_sample.keys():
|
||||
if cat not in x:
|
||||
x[cat] = []
|
||||
y[cat] = []
|
||||
x[cat] += x_sample[cat]
|
||||
y[cat] += y_sample[cat]
|
||||
instance.free()
|
||||
|
||||
# fit_xy
|
||||
for (cname, comp) in components.items():
|
||||
x = x_combined[cname]
|
||||
y = y_combined[cname]
|
||||
for cat in x.keys():
|
||||
x[cat] = np.array(x[cat], dtype=np.float32)
|
||||
y[cat] = np.array(y[cat])
|
||||
comp.fit_xy(x, y)
|
||||
|
||||
@@ -117,22 +117,14 @@ class DynamicConstraintsComponent(Component):
|
||||
return pred
|
||||
|
||||
@overrides
|
||||
def fit(self, training_instances: List[Instance]) -> None:
|
||||
collected_cids = set()
|
||||
for instance in training_instances:
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
if (
|
||||
sample.after_mip is None
|
||||
or sample.after_mip.extra is None
|
||||
or sample.after_mip.extra[self.attr] is None
|
||||
):
|
||||
continue
|
||||
collected_cids |= sample.after_mip.extra[self.attr]
|
||||
instance.free()
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(collected_cids))
|
||||
super().fit(training_instances)
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
if (
|
||||
sample.after_mip is None
|
||||
or sample.after_mip.extra is None
|
||||
or sample.after_mip.extra[self.attr] is None
|
||||
):
|
||||
return
|
||||
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
@@ -119,8 +119,8 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
return self.dynamic.sample_predict(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit(self, training_instances: List[Instance]) -> None:
|
||||
self.dynamic.fit(training_instances)
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
self.dynamic.pre_sample_xy(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
@@ -112,8 +112,8 @@ class UserCutsComponent(Component):
|
||||
return self.dynamic.sample_predict(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit(self, training_instances: List["Instance"]) -> None:
|
||||
self.dynamic.fit(training_instances)
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
self.dynamic.pre_sample_xy(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
@@ -325,7 +325,6 @@ class LearningSolver:
|
||||
instance=instance,
|
||||
model=model,
|
||||
tee=tee,
|
||||
discard_output=True,
|
||||
)
|
||||
self.fit([instance])
|
||||
instance.instance = None
|
||||
@@ -396,9 +395,7 @@ class LearningSolver:
|
||||
if len(training_instances) == 0:
|
||||
logger.warning("Empty list of training instances provided. Skipping.")
|
||||
return
|
||||
for component in self.components.values():
|
||||
logger.info(f"Fitting {component.__class__.__name__}...")
|
||||
component.fit(training_instances)
|
||||
Component.fit_multiple(self.components, training_instances)
|
||||
|
||||
def _add_component(self, component: Component) -> None:
|
||||
name = component.__class__.__name__
|
||||
|
||||
Reference in New Issue
Block a user