mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Add pre argument to sample_xy
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
|
||||
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -101,8 +101,9 @@ class UserCutsComponent(Component):
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
pre: Optional[List[Any]] = None,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
return self.dynamic.sample_xy(instance, sample)
|
||||
return self.dynamic.sample_xy(instance, sample, pre=pre)
|
||||
|
||||
def sample_predict(
|
||||
self,
|
||||
@@ -112,8 +113,8 @@ class UserCutsComponent(Component):
|
||||
return self.dynamic.sample_predict(instance, sample)
|
||||
|
||||
@overrides
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
||||
self.dynamic.pre_sample_xy(instance, sample)
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||
return self.dynamic.pre_sample_xy(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
|
||||
Reference in New Issue
Block a user