|
|
@ -11,6 +11,8 @@ from miplearn.features import Sample
|
|
|
|
from miplearn.instance.base import Instance
|
|
|
|
from miplearn.instance.base import Instance
|
|
|
|
from miplearn.types import LearningSolveStats
|
|
|
|
from miplearn.types import LearningSolveStats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from p_tqdm import p_umap
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from miplearn.solvers.learning import LearningSolver
|
|
|
|
from miplearn.solvers.learning import LearningSolver
|
|
|
|
|
|
|
|
|
|
|
@ -159,7 +161,6 @@ class Component(EnforceOverrides):
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
instance: Optional[Instance],
|
|
|
|
instance: Optional[Instance],
|
|
|
|
sample: Sample,
|
|
|
|
sample: Sample,
|
|
|
|
pre: Optional[List[Any]] = None,
|
|
|
|
|
|
|
|
) -> Tuple[Dict, Dict]:
|
|
|
|
) -> Tuple[Dict, Dict]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Returns a pair of x and y dictionaries containing, respectively, the matrices
|
|
|
|
Returns a pair of x and y dictionaries containing, respectively, the matrices
|
|
|
@ -168,6 +169,9 @@ class Component(EnforceOverrides):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pre_fit(self, pre: List[Any]):
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def user_cut_cb(
|
|
|
|
def user_cut_cb(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
solver: "LearningSolver",
|
|
|
|
solver: "LearningSolver",
|
|
|
@ -183,6 +187,7 @@ class Component(EnforceOverrides):
|
|
|
|
def fit_multiple(
|
|
|
|
def fit_multiple(
|
|
|
|
components: Dict[str, "Component"],
|
|
|
|
components: Dict[str, "Component"],
|
|
|
|
instances: List[Instance],
|
|
|
|
instances: List[Instance],
|
|
|
|
|
|
|
|
n_jobs: int = 1,
|
|
|
|
) -> None:
|
|
|
|
) -> None:
|
|
|
|
def _pre_sample_xy(instance: Instance) -> Dict:
|
|
|
|
def _pre_sample_xy(instance: Instance) -> Dict:
|
|
|
|
pre_instance: Dict = {}
|
|
|
|
pre_instance: Dict = {}
|
|
|
@ -195,7 +200,17 @@ class Component(EnforceOverrides):
|
|
|
|
instance.free()
|
|
|
|
instance.free()
|
|
|
|
return pre_instance
|
|
|
|
return pre_instance
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_xy(instance: Instance, pre: Dict) -> Tuple[Dict, Dict]:
|
|
|
|
pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs)
|
|
|
|
|
|
|
|
pre_combined: Dict = {}
|
|
|
|
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
|
|
|
|
pre_combined[cname] = []
|
|
|
|
|
|
|
|
for p in pre:
|
|
|
|
|
|
|
|
pre_combined[cname].extend(p[cname])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
|
|
|
|
comp.pre_fit(pre_combined[cname])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_xy(instance: Instance) -> Tuple[Dict, Dict]:
|
|
|
|
x_instance: Dict = {}
|
|
|
|
x_instance: Dict = {}
|
|
|
|
y_instance: Dict = {}
|
|
|
|
y_instance: Dict = {}
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
for (cname, comp) in components.items():
|
|
|
@ -206,7 +221,7 @@ class Component(EnforceOverrides):
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
x = x_instance[cname]
|
|
|
|
x = x_instance[cname]
|
|
|
|
y = y_instance[cname]
|
|
|
|
y = y_instance[cname]
|
|
|
|
x_sample, y_sample = comp.sample_xy(instance, sample, pre[cname])
|
|
|
|
x_sample, y_sample = comp.sample_xy(instance, sample)
|
|
|
|
for cat in x_sample.keys():
|
|
|
|
for cat in x_sample.keys():
|
|
|
|
if cat not in x:
|
|
|
|
if cat not in x:
|
|
|
|
x[cat] = []
|
|
|
|
x[cat] = []
|
|
|
@ -216,15 +231,7 @@ class Component(EnforceOverrides):
|
|
|
|
instance.free()
|
|
|
|
instance.free()
|
|
|
|
return x_instance, y_instance
|
|
|
|
return x_instance, y_instance
|
|
|
|
|
|
|
|
|
|
|
|
pre = [_pre_sample_xy(instance) for instance in instances]
|
|
|
|
xy = p_umap(_sample_xy, instances)
|
|
|
|
|
|
|
|
|
|
|
|
pre_combined: Dict = {}
|
|
|
|
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
|
|
|
|
pre_combined[cname] = []
|
|
|
|
|
|
|
|
for p in pre:
|
|
|
|
|
|
|
|
pre_combined[cname].extend(p[cname])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
xy = [_sample_xy(instances, pre_combined) for instances in instances]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
for (cname, comp) in components.items():
|
|
|
|
x_comp: Dict = {}
|
|
|
|
x_comp: Dict = {}
|
|
|
|