mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
LearningSolver: Load each instance exactly twice during fit
This commit is contained in:
@@ -117,22 +117,14 @@ class DynamicConstraintsComponent(Component):
|
||||
return pred
|
||||
|
||||
@overrides
|
||||
def fit(self, training_instances: List[Instance]) -> None:
|
||||
collected_cids = set()
|
||||
for instance in training_instances:
|
||||
instance.load()
|
||||
for sample in instance.samples:
|
||||
if (
|
||||
sample.after_mip is None
|
||||
or sample.after_mip.extra is None
|
||||
or sample.after_mip.extra[self.attr] is None
|
||||
):
|
||||
continue
|
||||
collected_cids |= sample.after_mip.extra[self.attr]
|
||||
instance.free()
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(collected_cids))
|
||||
super().fit(training_instances)
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
if (
|
||||
sample.after_mip is None
|
||||
or sample.after_mip.extra is None
|
||||
or sample.after_mip.extra[self.attr] is None
|
||||
):
|
||||
return
|
||||
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
Reference in New Issue
Block a user