mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Add pre argument to sample_xy
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Hashable, List, Tuple, Optional
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
|
||||
self,
|
||||
instance: Optional[Instance],
|
||||
sample: Sample,
|
||||
pre: Optional[List[Any]] = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
assert pre is not None
|
||||
known_cids: Set = set()
|
||||
for cids in pre:
|
||||
known_cids |= cids
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(known_cids))
|
||||
x, y, _ = self.sample_xy_with_cids(instance, sample)
|
||||
return x, y
|
||||
|
||||
@@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
|
||||
return pred
|
||||
|
||||
@overrides
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||
if (
|
||||
sample.after_mip is None
|
||||
or sample.after_mip.extra is None
|
||||
or sample.after_mip.extra[self.attr] is None
|
||||
):
|
||||
return
|
||||
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
|
||||
return sample.after_mip.extra[self.attr]
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
Reference in New Issue
Block a user