Add pre argument to sample_xy

This commit is contained in:
2021-04-13 19:19:49 -05:00
parent a01c179341
commit bec7dae6d9
8 changed files with 64 additions and 33 deletions

View File

@@ -159,6 +159,7 @@ class Component(EnforceOverrides):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
"""
Returns a pair of x and y dictionaries containing, respectively, the matrices
@@ -175,7 +176,7 @@ class Component(EnforceOverrides):
) -> None:
return
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
pass
@staticmethod
@@ -183,28 +184,29 @@ class Component(EnforceOverrides):
components: Dict[str, "Component"],
instances: List[Instance],
) -> None:
x_combined: Dict = {}
y_combined: Dict = {}
for (cname, comp) in components.items():
x_combined[cname] = {}
y_combined[cname] = {}
# pre_sample_xy
for instance in instances:
def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {}
for (cname, comp) in components.items():
pre_instance[cname] = []
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
comp.pre_sample_xy(instance, sample)
pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
instance.free()
return pre_instance
# sample_xy
for instance in instances:
def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
x_sample, y_sample = comp.sample_xy(instance, sample)
x = x_instance[cname]
y = y_instance[cname]
x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
for cat in x_sample.keys():
if cat not in x:
x[cat] = []
@@ -212,12 +214,29 @@ class Component(EnforceOverrides):
x[cat] += x_sample[cat]
y[cat] += y_sample[cat]
instance.free()
return x_instance, y_instance
# fit_xy
pre = [_pre_sample_xy(instance) for instance in instances]
pre_combined: Dict = {}
for (cname, comp) in components.items():
x = x_combined[cname]
y = y_combined[cname]
for cat in x.keys():
x[cat] = np.array(x[cat], dtype=np.float32)
y[cat] = np.array(y[cat])
comp.fit_xy(x, y)
pre_combined[cname] = []
for p in pre:
pre_combined[cname].extend(p[cname])
xy = [_sample_xy(instances, pre_combined) for instances in instances]
for (cname, comp) in components.items():
x_comp: Dict = {}
y_comp: Dict = {}
for (x, y) in xy:
for cat in x[cname].keys():
if cat not in x_comp:
x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])
comp.fit_xy(x_comp, y_comp)

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
import numpy as np
from overrides import overrides
@@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y
@@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
if (
sample.after_mip is None
or sample.after_mip.extra is None
or sample.after_mip.extra[self.attr] is None
):
return
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
return sample.after_mip.extra[self.attr]
@overrides
def fit_xy(

View File

@@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component):
self,
instance: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict(
self,
@@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
import numpy as np
from overrides import overrides
@@ -101,8 +101,9 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
return self.dynamic.sample_xy(instance, sample, pre=pre)
def sample_predict(
self,
@@ -112,8 +113,8 @@ class UserCutsComponent(Component):
return self.dynamic.sample_predict(instance, sample)
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
self.dynamic.pre_sample_xy(instance, sample)
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return self.dynamic.pre_sample_xy(instance, sample)
@overrides
def fit_xy(

View File

@@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features
assert sample.after_load is not None

View File

@@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}

View File

@@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component):
self,
_: Optional[Instance],
sample: Sample,
pre: Optional[List[Any]] = None,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x, y, __ = self._sample_xy_with_cids(sample)
return x, y