mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add pre argument to sample_xy
This commit is contained in:
@@ -159,6 +159,7 @@ class Component(EnforceOverrides):
|
|||||||
self,
|
self,
|
||||||
instance: Optional[Instance],
|
instance: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
"""
|
"""
|
||||||
Returns a pair of x and y dictionaries containing, respectively, the matrices
|
Returns a pair of x and y dictionaries containing, respectively, the matrices
|
||||||
@@ -175,7 +176,7 @@ class Component(EnforceOverrides):
|
|||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -183,28 +184,29 @@ class Component(EnforceOverrides):
|
|||||||
components: Dict[str, "Component"],
|
components: Dict[str, "Component"],
|
||||||
instances: List[Instance],
|
instances: List[Instance],
|
||||||
) -> None:
|
) -> None:
|
||||||
x_combined: Dict = {}
|
def _pre_sample_xy(instance: Instance) -> Dict:
|
||||||
y_combined: Dict = {}
|
pre_instance: Dict = {}
|
||||||
for (cname, comp) in components.items():
|
for (cname, comp) in components.items():
|
||||||
x_combined[cname] = {}
|
pre_instance[cname] = []
|
||||||
y_combined[cname] = {}
|
|
||||||
|
|
||||||
# pre_sample_xy
|
|
||||||
for instance in instances:
|
|
||||||
instance.load()
|
instance.load()
|
||||||
for sample in instance.samples:
|
for sample in instance.samples:
|
||||||
for (cname, comp) in components.items():
|
for (cname, comp) in components.items():
|
||||||
comp.pre_sample_xy(instance, sample)
|
pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
|
||||||
instance.free()
|
instance.free()
|
||||||
|
return pre_instance
|
||||||
|
|
||||||
# sample_xy
|
def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
|
||||||
for instance in instances:
|
x_instance: Dict = {}
|
||||||
|
y_instance: Dict = {}
|
||||||
|
for (cname, comp) in components.items():
|
||||||
|
x_instance[cname] = {}
|
||||||
|
y_instance[cname] = {}
|
||||||
instance.load()
|
instance.load()
|
||||||
for sample in instance.samples:
|
for sample in instance.samples:
|
||||||
for (cname, comp) in components.items():
|
for (cname, comp) in components.items():
|
||||||
x = x_combined[cname]
|
x = x_instance[cname]
|
||||||
y = y_combined[cname]
|
y = y_instance[cname]
|
||||||
x_sample, y_sample = comp.sample_xy(instance, sample)
|
x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
|
||||||
for cat in x_sample.keys():
|
for cat in x_sample.keys():
|
||||||
if cat not in x:
|
if cat not in x:
|
||||||
x[cat] = []
|
x[cat] = []
|
||||||
@@ -212,12 +214,29 @@ class Component(EnforceOverrides):
|
|||||||
x[cat] += x_sample[cat]
|
x[cat] += x_sample[cat]
|
||||||
y[cat] += y_sample[cat]
|
y[cat] += y_sample[cat]
|
||||||
instance.free()
|
instance.free()
|
||||||
|
return x_instance, y_instance
|
||||||
|
|
||||||
# fit_xy
|
pre = [_pre_sample_xy(instance) for instance in instances]
|
||||||
|
|
||||||
|
pre_combined: Dict = {}
|
||||||
for (cname, comp) in components.items():
|
for (cname, comp) in components.items():
|
||||||
x = x_combined[cname]
|
pre_combined[cname] = []
|
||||||
y = y_combined[cname]
|
for p in pre:
|
||||||
for cat in x.keys():
|
pre_combined[cname].extend(p[cname])
|
||||||
x[cat] = np.array(x[cat], dtype=np.float32)
|
|
||||||
y[cat] = np.array(y[cat])
|
xy = [_sample_xy(instances, pre_combined) for instances in instances]
|
||||||
comp.fit_xy(x, y)
|
|
||||||
|
for (cname, comp) in components.items():
|
||||||
|
x_comp: Dict = {}
|
||||||
|
y_comp: Dict = {}
|
||||||
|
for (x, y) in xy:
|
||||||
|
for cat in x[cname].keys():
|
||||||
|
if cat not in x_comp:
|
||||||
|
x_comp[cat] = []
|
||||||
|
y_comp[cat] = []
|
||||||
|
x_comp[cat].extend(x[cname][cat])
|
||||||
|
y_comp[cat].extend(y[cname][cat])
|
||||||
|
for cat in x_comp.keys():
|
||||||
|
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
|
||||||
|
y_comp[cat] = np.array(y_comp[cat])
|
||||||
|
comp.fit_xy(x_comp, y_comp)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Hashable, List, Tuple, Optional
|
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
@@ -89,7 +89,14 @@ class DynamicConstraintsComponent(Component):
|
|||||||
self,
|
self,
|
||||||
instance: Optional[Instance],
|
instance: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
|
assert pre is not None
|
||||||
|
known_cids: Set = set()
|
||||||
|
for cids in pre:
|
||||||
|
known_cids |= cids
|
||||||
|
self.known_cids.clear()
|
||||||
|
self.known_cids.extend(sorted(known_cids))
|
||||||
x, y, _ = self.sample_xy_with_cids(instance, sample)
|
x, y, _ = self.sample_xy_with_cids(instance, sample)
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
@@ -117,14 +124,14 @@ class DynamicConstraintsComponent(Component):
|
|||||||
return pred
|
return pred
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||||
if (
|
if (
|
||||||
sample.after_mip is None
|
sample.after_mip is None
|
||||||
or sample.after_mip.extra is None
|
or sample.after_mip.extra is None
|
||||||
or sample.after_mip.extra[self.attr] is None
|
or sample.after_mip.extra[self.attr] is None
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
self.known_cids.extend(sorted(sample.after_mip.extra[self.attr]))
|
return sample.after_mip.extra[self.attr]
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
|
|||||||
@@ -108,8 +108,9 @@ class DynamicLazyConstraintsComponent(Component):
|
|||||||
self,
|
self,
|
||||||
instance: Optional[Instance],
|
instance: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
return self.dynamic.sample_xy(instance, sample)
|
return self.dynamic.sample_xy(instance, sample, pre=pre)
|
||||||
|
|
||||||
def sample_predict(
|
def sample_predict(
|
||||||
self,
|
self,
|
||||||
@@ -119,8 +120,8 @@ class DynamicLazyConstraintsComponent(Component):
|
|||||||
return self.dynamic.sample_predict(instance, sample)
|
return self.dynamic.sample_predict(instance, sample)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||||
self.dynamic.pre_sample_xy(instance, sample)
|
return self.dynamic.pre_sample_xy(instance, sample)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
|
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
@@ -101,8 +101,9 @@ class UserCutsComponent(Component):
|
|||||||
self,
|
self,
|
||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
return self.dynamic.sample_xy(instance, sample)
|
return self.dynamic.sample_xy(instance, sample, pre=pre)
|
||||||
|
|
||||||
def sample_predict(
|
def sample_predict(
|
||||||
self,
|
self,
|
||||||
@@ -112,8 +113,8 @@ class UserCutsComponent(Component):
|
|||||||
return self.dynamic.sample_predict(instance, sample)
|
return self.dynamic.sample_predict(instance, sample)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> None:
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||||
self.dynamic.pre_sample_xy(instance, sample)
|
return self.dynamic.pre_sample_xy(instance, sample)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def fit_xy(
|
def fit_xy(
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ class ObjectiveValueComponent(Component):
|
|||||||
self,
|
self,
|
||||||
_: Optional[Instance],
|
_: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||||
# Instance features
|
# Instance features
|
||||||
assert sample.after_load is not None
|
assert sample.after_load is not None
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
self,
|
self,
|
||||||
_: Optional[Instance],
|
_: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
|
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
|
||||||
x: Dict = {}
|
x: Dict = {}
|
||||||
y: Dict = {}
|
y: Dict = {}
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ class StaticLazyConstraintsComponent(Component):
|
|||||||
self,
|
self,
|
||||||
_: Optional[Instance],
|
_: Optional[Instance],
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
|
pre: Optional[List[Any]] = None,
|
||||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||||
x, y, __ = self._sample_xy_with_cids(sample)
|
x, y, __ = self._sample_xy_with_cids(sample)
|
||||||
return x, y
|
return x, y
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ def training_instances() -> List[Instance]:
|
|||||||
|
|
||||||
def test_sample_xy(training_instances: List[Instance]) -> None:
|
def test_sample_xy(training_instances: List[Instance]) -> None:
|
||||||
comp = DynamicLazyConstraintsComponent()
|
comp = DynamicLazyConstraintsComponent()
|
||||||
comp.dynamic.known_cids = ["c1", "c2", "c3", "c4"]
|
|
||||||
x_expected = {
|
x_expected = {
|
||||||
"type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]],
|
"type-a": [[5.0, 1.0, 2.0, 3.0], [5.0, 4.0, 5.0, 6.0]],
|
||||||
"type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]],
|
"type-b": [[5.0, 1.0, 2.0], [5.0, 3.0, 4.0]],
|
||||||
@@ -99,6 +98,7 @@ def test_sample_xy(training_instances: List[Instance]) -> None:
|
|||||||
x_actual, y_actual = comp.sample_xy(
|
x_actual, y_actual = comp.sample_xy(
|
||||||
training_instances[0],
|
training_instances[0],
|
||||||
training_instances[0].samples[0],
|
training_instances[0].samples[0],
|
||||||
|
pre=[{"c1", "c2", "c3", "c4"}],
|
||||||
)
|
)
|
||||||
assert_equals(x_actual, x_expected)
|
assert_equals(x_actual, x_expected)
|
||||||
assert_equals(y_actual, y_expected)
|
assert_equals(y_actual, y_expected)
|
||||||
|
|||||||
Reference in New Issue
Block a user