Add pre argument to sample_xy

master
Alinson S. Xavier 5 years ago
parent a01c179341
commit bec7dae6d9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -159,6 +159,7 @@ class Component(EnforceOverrides):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
""" """
Returns a pair of x and y dictionaries containing, respectively, the matrices Returns a pair of x and y dictionaries containing, respectively, the matrices
@ -175,7 +176,7 @@ class Component(EnforceOverrides):
) -> None: ) -> None:
return return
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
pass pass
@staticmethod @staticmethod
@ -183,28 +184,29 @@ class Component(EnforceOverrides):
components: Dict[str, "Component"], components: Dict[str, "Component"],
instances: List[Instance], instances: List[Instance],
) -> None: ) -> None:
x_combined: Dict = {} def _pre_sample_xy(instance: Instance) -> Dict:
y_combined: Dict = {} pre_instance: Dict = {}
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x_combined[cname] = {} pre_instance[cname] = []
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample) pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
instance.free() instance.free()
return pre_instance
# sample_xy def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
for instance in instances: x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x = x_combined[cname] x = x_instance[cname]
y = y_combined[cname] y = y_instance[cname]
x_sample, y_sample = comp.sample_xy(instance, sample) x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
for cat in x_sample.keys(): for cat in x_sample.keys():
if cat not in x: if cat not in x:
x[cat] = [] x[cat] = []
@ -212,12 +214,29 @@ class Component(EnforceOverrides):
x[cat] += x_sample[cat] x[cat] += x_sample[cat]
y[cat] += y_sample[cat] y[cat] += y_sample[cat]
instance.free() instance.free()
return x_instance, y_instance
pre = [_pre_sample_xy(instance) for instance in instances]
pre_combined: Dict = {}
for (cname, comp) in components.items():
pre_combined[cname] = []
for p in pre:
pre_combined[cname].extend(p[cname])
xy = [_sample_xy(instances, pre_combined) for instances in instances]
# fit_xy
for (cname, comp) in components.items(): for (cname, comp) in components.items():
x = x_combined[cname] x_comp: Dict = {}
y = y_combined[cname] y_comp: Dict = {}
for cat in x.keys(): for (x, y) in xy:
x[cat] = np.array(x[cat], dtype=np.float32) for cat in x[cname].keys():
y[cat] = np.array(y[cat]) if cat not in x_comp:
comp.fit_xy(x, y) x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])
comp.fit_xy(x_comp, y_comp)

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Dict, Hashable, List, Tuple, Optional from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
x, y, _ = self.sample_xy_with_cids(instance, sample) x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y return x, y
@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
return pred return pred
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if ( if (
sample.after_mip is None sample.after_mip is None
or sample.after_mip.extra is None or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None or sample.after_mip.extra[self.attr] is None
): ):
return return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr])) return sample.after_mip.extra[self.attr]
@overrides @overrides
def fit_xy( def fit_xy(

@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component):
self, self,
instance: Optional[Instance], instance: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample) return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict( def sample_predict(
self, self,
@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component):
return self.dynamic.sample_predict(instance, sample) return self.dynamic.sample_predict(instance, sample)
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
self.dynamic.pre_sample_xy(instance, sample) return self.dynamic.pre_sample_xy(instance, sample)
@overrides @overrides
def fit_xy( def fit_xy(

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@ -101,8 +101,9 @@ class UserCutsComponent(Component):
self, self,
instance: "Instance", instance: "Instance",
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample) return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict( def sample_predict(
self, self,
@ -112,8 +113,8 @@ class UserCutsComponent(Component):
return self.dynamic.sample_predict(instance, sample) return self.dynamic.sample_predict(instance, sample)
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
self.dynamic.pre_sample_xy(instance, sample) return self.dynamic.pre_sample_xy(instance, sample)
@overrides @overrides
def fit_xy( def fit_xy(

@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features # Instance features
assert sample.after_load is not None assert sample.after_load is not None

@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]: ) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {} x: Dict = {}
y: Dict = {} y: Dict = {}

@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component):
self, self,
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x, y, __ = self._sample_xy_with_cids(sample) x, y, __ = self._sample_xy_with_cids(sample)
return x, y return x, y

@ -87,7 +87,6 @@ def training_instances() -> List[Instance]:
def test_sample_xy(training_instances: List[Instance]) -> None: def test_sample_xy(training_instances: List[Instance]) -> None:
comp = DynamicLazyConstraintsComponent() comp = DynamicLazyConstraintsComponent()
comp.dynamic.known_cids = ["c1", "c2", "c3", "c4"]
x_expected = { x_expected = {
"type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]], "type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]],
"type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]], "type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]],
@ -99,6 +98,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
x_actual, y_actual = comp.sample_xy( x_actual, y_actual = comp.sample_xy(
training_instances[0], training_instances[0],
training_instances[0].samples[0], training_instances[0].samples[0],
pre=[{"c1", "c2", "c3", "c4"}],
) )
assert_equals(x_actual, x_expected) assert_equals(x_actual, x_expected)
assert_equals(y_actual, y_expected) assert_equals(y_actual, y_expected)

Loading…
Cancel
Save