Add pre argument to sample_xy

This commit is contained in:
2021-04-13 19:19:49 -05:00
parent a01c179341
commit bec7dae6d9
8 changed files with 64 additions and 33 deletions

View File

@@ -159,6 +159,7 @@ class Component(EnforceOverrides):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
""" """
Returns a pair of x and y dictionaries containing, respectively, the matrices Returns a pair of x and y dictionaries containing, respectively, the matrices
@@ -175,7 +176,7 @@ class Component(EnforceOverrides):
) -> None: ) -> None:
return return
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
pass pass
@staticmethod @staticmethod
@@ -183,28 +184,29 @@ class Component(EnforceOverrides):
components: Dict[str, "Component"], components: Dict[str, "Component"],
instances: List[Instance], instances: List[Instance],
) -> None: ) -> None:
x_combined: Dict = {} def _pre_sample_xy(instance: Instance) -> Dict:
y_combined: Dict = {} pre_instance: Dict = {}
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x_combined[cname] = {} pre_instance[cname] = []
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample) pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
instance.free() instance.free()
return pre_instance
# sample_xy def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
for instance in instances: x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x = x_combined[cname] x = x_instance[cname]
y = y_combined[cname] y = y_instance[cname]
x_sample, y_sample = comp.sample_xy(instance, sample) x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
for cat in x_sample.keys(): for cat in x_sample.keys():
if cat not in x: if cat not in x:
x[cat] = [] x[cat] = []
@@ -212,12 +214,29 @@ class Component(EnforceOverrides):
x[cat] += x_sample[cat] x[cat] += x_sample[cat]
y[cat] += y_sample[cat] y[cat] += y_sample[cat]
instance.free() instance.free()
return x_instance, y_instance
# fit_xy pre = [_pre_sample_xy(instance) for instance in instances]
pre_combined: Dict = {}
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x = x_combined[cname] pre_combined[cname] = []
y = y_combined[cname] for p in pre:
for cat in x.keys(): pre_combined[cname].extend(p[cname])
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat]) xy = [_sample_xy(instances, pre_combined) for instances in instances]
comp.fit_xy(x, y)
for (cname, comp) in components.items():
x_comp: Dict = {}
y_comp: Dict = {}
for (x, y) in xy:
for cat in x[cname].keys():
if cat not in x_comp:
x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])
comp.fit_xy(x_comp, y_comp)

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Dict, Hashable, List, Tuple, Optional from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
x, y, _ = self.sample_xy_with_cids(instance, sample) x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y return x, y
@@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
return pred return pred
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if ( if (
sample.after_mip is None sample.after_mip is None
or sample.after_mip.extra is None or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None or sample.after_mip.extra[self.attr] is None
): ):
return return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr])) return sample.after_mip.extra[self.attr]
@overrides @overrides
def fit_xy( def fit_xy(

View File

@@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample) return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict( def sample_predict(
self, self,
@@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component):
return self.dynamic.sample_predict(instance, sample) return self.dynamic.sample_predict(instance, sample)
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
self.dynamic.pre_sample_xy(instance, sample) return self.dynamic.pre_sample_xy(instance, sample)
@overrides @overrides
def fit_xy( def fit_xy(

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@@ -101,8 +101,9 @@ class UserCutsComponent(Component):
self, self,
instance: "Instance", instance: "Instance",
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample) return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict( def sample_predict(
self, self,
@@ -112,8 +113,8 @@ class UserCutsComponent(Component):
return self.dynamic.sample_predict(instance, sample) return self.dynamic.sample_predict(instance, sample)
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
self.dynamic.pre_sample_xy(instance, sample) return self.dynamic.pre_sample_xy(instance, sample)
@overrides @overrides
def fit_xy( def fit_xy(

View File

@@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features # Instance features
assert sample.after_load is not None assert sample.after_load is not None

View File

@@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]: ) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {} x: Dict = {}
y: Dict = {} y: Dict = {}

View File

@@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x, y, __ = self._sample_xy_with_cids(sample) x, y, __ = self._sample_xy_with_cids(sample)
return x, y return x, y

View File

@@ -87,7 +87,6 @@ def training_instances() -> List[Instance]:
def test_sample_xy(training_instances: List[Instance]) -> None: def test_sample_xy(training_instances: List[Instance]) -> None:
comp = DynamicLazyConstraintsComponent() comp = DynamicLazyConstraintsComponent()
comp.dynamic.known_cids = ["c1", "c2", "c3", "c4"]
x_expected = { x_expected = {
"type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]], "type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]],
"type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]], "type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]],
@@ -99,6 +98,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
x_actual, y_actual = comp.sample_xy( x_actual, y_actual = comp.sample_xy(
training_instances[0], training_instances[0],
training_instances[0].samples[0], training_instances[0].samples[0],
pre=[{"c1", "c2", "c3", "c4"}],
) )
assert_equals(x_actual, x_expected) assert_equals(x_actual, x_expected)
assert_equals(y_actual, y_expected) assert_equals(y_actual, y_expected)