Add pre argument to sample_xy

master
Alinson S. Xavier 5 years ago
parent a01c179341
commit bec7dae6d9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -159,6 +159,7 @@ class Component(EnforceOverrides):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
"""
Returns a pair of x and y dictionaries containing, respectively, the matrices
@ -175,7 +176,7 @@ class Component(EnforceOverrides):
) -> None:
return
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
pass
@staticmethod
@ -183,28 +184,29 @@ class Component(EnforceOverrides):
components: Dict[str, "Component"],
instances: List[Instance],
) -> None:
x_combined: Dict = {}
y_combined: Dict = {}
def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {}
for (cname, comp) in components.items():
x_combined[cname] = {}
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
pre_instance[cname] = []
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample)
pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
instance.free()
return pre_instance
# sample_xy
for instance in instances:
def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
x_sample, y_sample = comp.sample_xy(instance, sample)
x = x_instance[cname]
y = y_instance[cname]
x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
for cat in x_sample.keys():
if cat not in x:
x[cat] = []
@ -212,12 +214,29 @@ class Component(EnforceOverrides):
x[cat] += x_sample[cat]
y[cat] += y_sample[cat]
instance.free()
return x_instance, y_instance
pre = [_pre_sample_xy(instance) for instance in instances]
pre_combined: Dict = {}
for (cname, comp) in components.items():
pre_combined[cname] = []
for p in pre:
pre_combined[cname].extend(p[cname])
xy = [_sample_xy(instances, pre_combined) for instances in instances]
# fit_xy
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
for cat in x.keys():
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat])
comp.fit_xy(x, y)
x_comp: Dict = {}
y_comp: Dict = {}
for (x, y) in xy:
for cat in x[cname].keys():
if cat not in x_comp:
x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])
comp.fit_xy(x_comp, y_comp)

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
import numpy as np
from overrides import overrides
@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y
@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
return sample.after_mip.extra[self.attr]
@overrides
def fit_xy(

@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict(
self,
@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
import numpy as np
from overrides import overrides
@ -101,8 +101,9 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict(
self,
@ -112,8 +113,8 @@ class UserCutsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(

@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features
assert sample.after_load is not None

@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}

@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x, y, __ = self._sample_xy_with_cids(sample)
return x, y

@ -87,7 +87,6 @@ def training_instances() -> List[Instance]:
def test_sample_xy(training_instances: List[Instance]) -> None:
comp = DynamicLazyConstraintsComponent()
comp.dynamic.known_cids = ["c1", "c2", "c3", "c4"]
x_expected = {
"type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]],
"type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]],
@ -99,6 +98,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
x_actual, y_actual = comp.sample_xy(
training_instances[0],
training_instances[0].samples[0],
pre=[{"c1", "c2", "c3", "c4"}],
)
assert_equals(x_actual, x_expected)
assert_equals(y_actual, y_expected)

Loading…
Cancel
Save