Add pre argument to sample_xy

This commit is contained in:
2021-04-13 19:19:49 -05:00
parent a01c179341
commit bec7dae6d9
8 changed files with 64 additions and 33 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
import numpy as np
from overrides import overrides
@@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y
@@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
return sample.after_mip.extra[self.attr]
@overrides
def fit_xy(