diff --git a/miplearn/components/component.py b/miplearn/components/component.py index 46dd381..f445a63 100644 --- a/miplearn/components/component.py +++ b/miplearn/components/component.py @@ -159,6 +159,7 @@ class Component(EnforceOverrides): self, instance: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict, Dict]: """ Returns a pair of x and y dictionaries containing, respectively, the matrices @@ -175,7 +176,7 @@ class Component(EnforceOverrides): ) -> None: return - def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: + def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any: pass @staticmethod @@ -183,28 +184,29 @@ class Component(EnforceOverrides): components: Dict[str, "Component"], instances: List[Instance], ) -> None: - x_combined: Dict = {} - y_combined: Dict = {} - for (cname, comp) in components.items(): - x_combined[cname] = {} - y_combined[cname] = {} - - # pre_sample_xy - for instance in instances: + def _pre_sample_xy(instance: Instance) -> Dict: + pre_instance: Dict = {} + for (cname, comp) in components.items(): + pre_instance[cname] = [] instance.load() for sample in instance.samples: for (cname, comp) in components.items(): - comp.pre_sample_xy(instance, sample) + pre_instance[cname].append(comp.pre_sample_xy(instance, sample)) instance.free() - - # sample_xy - for instance in instances: + return pre_instance + + def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]: + x_instance: Dict = {} + y_instance: Dict = {} + for (cname, comp) in components.items(): + x_instance[cname] = {} + y_instance[cname] = {} instance.load() for sample in instance.samples: for (cname, comp) in components.items(): - x = x_combined[cname] - y = y_combined[cname] - x_sample, y_sample = comp.sample_xy(instance, sample) + x = x_instance[cname] + y = y_instance[cname] + x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname]) for cat in x_sample.keys(): if cat not in x: x[cat] = [] @@ -212,12 +214,29 @@ class Component(EnforceOverrides): x[cat] += x_sample[cat] y[cat] += y_sample[cat] instance.free() + return x_instance, y_instance + + pre = [_pre_sample_xy(instance) for instance in instances] + + pre_combined: Dict = {} + for (cname, comp) in components.items(): + pre_combined[cname] = [] + for p in pre: + pre_combined[cname].extend(p[cname]) + + xy = [_sample_xy(instances, pre_combined) for instances in instances] - # fit_xy for (cname, comp) in components.items(): - x = x_combined[cname] - y = y_combined[cname] - for cat in x.keys(): - x[cat] = np.array(x[cat], dtype=np.float32) - y[cat] = np.array(y[cat]) - comp.fit_xy(x, y) + x_comp: Dict = {} + y_comp: Dict = {} + for (x, y) in xy: + for cat in x[cname].keys(): + if cat not in x_comp: + x_comp[cat] = [] + y_comp[cat] = [] + x_comp[cat].extend(x[cname][cat]) + y_comp[cat].extend(y[cname][cat]) + for cat in x_comp.keys(): + x_comp[cat] = np.array(x_comp[cat], dtype=np.float32) + y_comp[cat] = np.array(y_comp[cat]) + comp.fit_xy(x_comp, y_comp) diff --git a/miplearn/components/dynamic_common.py b/miplearn/components/dynamic_common.py index 4a698cc..561becf 100644 --- a/miplearn/components/dynamic_common.py +++ b/miplearn/components/dynamic_common.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging -from typing import Dict, Hashable, List, Tuple, Optional +from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set import numpy as np from overrides import overrides @@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component): self, instance: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict, Dict]: + assert pre is not None + known_cids: Set = set() + for cids in pre: + known_cids |= cids + self.known_cids.clear() + self.known_cids.extend(sorted(known_cids)) x, y, _ = self.sample_xy_with_cids(instance, sample) return x, y @@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component): return pred @overrides - def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: + def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any: if ( sample.after_mip is None or sample.after_mip.extra is None or sample.after_mip.extra[self.attr] is None ): return - self.known_cids.extend(sorted(sample.after_mip.extra[self.attr])) + return sample.after_mip.extra[self.attr] @overrides def fit_xy( diff --git a/miplearn/components/dynamic_lazy.py b/miplearn/components/dynamic_lazy.py index 3efb3c8..ab1d412 100644 --- a/miplearn/components/dynamic_lazy.py +++ b/miplearn/components/dynamic_lazy.py @@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component): self, instance: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict, Dict]: - return self.dynamic.sample_xy(instance, sample) + return self.dynamic.sample_xy(instance, sample, pre=pre) def sample_predict( self, @@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component): return self.dynamic.sample_predict(instance, sample) @overrides - def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: - self.dynamic.pre_sample_xy(instance, sample) + def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any: + return self.dynamic.pre_sample_xy(instance, sample) @overrides def fit_xy( diff --git a/miplearn/components/dynamic_user_cuts.py b/miplearn/components/dynamic_user_cuts.py index 87bed5e..ebf7214 100644 --- a/miplearn/components/dynamic_user_cuts.py +++ b/miplearn/components/dynamic_user_cuts.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging -from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List +from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional import numpy as np from overrides import overrides @@ -101,8 +101,9 @@ class UserCutsComponent(Component): self, instance: "Instance", sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict, Dict]: - return self.dynamic.sample_xy(instance, sample) + return self.dynamic.sample_xy(instance, sample, pre=pre) def sample_predict( self, @@ -112,8 +113,8 @@ class UserCutsComponent(Component): return self.dynamic.sample_predict(instance, sample) @overrides - def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: - self.dynamic.pre_sample_xy(instance, sample) + def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any: + return self.dynamic.pre_sample_xy(instance, sample) @overrides def fit_xy( diff --git a/miplearn/components/objective.py b/miplearn/components/objective.py index 2fc5afd..a226bc7 100644 --- a/miplearn/components/objective.py +++ b/miplearn/components/objective.py @@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component): self, _: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: # Instance features assert sample.after_load is not None diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 00c676e..15359e3 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component): self, _: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]: x: Dict = {} y: Dict = {} diff --git a/miplearn/components/static_lazy.py b/miplearn/components/static_lazy.py index ebf204b..6858e5f 100644 --- a/miplearn/components/static_lazy.py +++ b/miplearn/components/static_lazy.py @@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component): self, _: Optional[Instance], sample: Sample, + pre: Optional[List[Any]] = None, ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: x, y, __ = self._sample_xy_with_cids(sample) return x, y diff --git a/tests/components/test_dynamic_lazy.py b/tests/components/test_dynamic_lazy.py index f17ffe8..b96c03f 100644 --- a/tests/components/test_dynamic_lazy.py +++ b/tests/components/test_dynamic_lazy.py @@ -87,7 +87,6 @@ def training_instances() -> List[Instance]: def test_sample_xy(training_instances: List[Instance]) -> None: comp = DynamicLazyConstraintsComponent() - comp.dynamic.known_cids = ["c1", "c2", "c3", "c4"] x_expected = { "type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]], "type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]], @@ -99,6 +98,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None: x_actual, y_actual = comp.sample_xy( training_instances[0], training_instances[0].samples[0], + pre=[{"c1", "c2", "c3", "c4"}], ) assert_equals(x_actual, x_expected) assert_equals(y_actual, y_expected)