Make xy_sample receive features, not instances

This commit is contained in:
2021-03-31 09:57:57 -05:00
parent 8fc9979b37
commit fe7bad885c
12 changed files with 158 additions and 119 deletions

View File

@@ -3,11 +3,11 @@
# Released under the modified BSD license. See COPYING.md for more details.
import numpy as np
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import LearningSolveStats, TrainingSample
from miplearn.types import LearningSolveStats, TrainingSample, Features
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
@@ -133,14 +133,16 @@ class Component:
@staticmethod
def xy_sample(
instance: Any,
training_sample: TrainingSample,
) -> Tuple[Dict, Dict]:
features: Features,
sample: TrainingSample,
) -> Optional[Tuple[Dict, Dict]]:
"""
Given a training sample, returns a pair of x and y dictionaries containing,
respectively, the matrices of ML features and the labels for the sample.
Given a set of features and a training sample, returns a pair of x and y
dictionaries containing, respectively, the matrices of ML features and the
labels for the sample. If the training sample does not include label
information, returns None.
"""
return {}, {}
return None
def xy_instances(
self,
@@ -149,8 +151,12 @@ class Component:
x_combined: Dict = {}
y_combined: Dict = {}
for instance in InstanceIterator(instances):
assert isinstance(instance, Instance)
for sample in instance.training_data:
x_sample, y_sample = self.xy_sample(instance, sample)
xy = self.xy_sample(instance.features, sample)
if xy is None:
continue
x_sample, y_sample = xy
for cat in x_sample.keys():
if cat not in x_combined:
x_combined[cat] = []

View File

@@ -5,14 +5,14 @@
import logging
import sys
from copy import deepcopy
from typing import Any, Dict, Tuple
from typing import Any, Dict, Tuple, Optional
import numpy as np
from tqdm.auto import tqdm
from miplearn.classifiers.counting import CountingClassifier
from miplearn.components.component import Component
from miplearn.types import TrainingSample
from miplearn.types import TrainingSample, Features
logger = logging.getLogger(__name__)
@@ -207,15 +207,16 @@ class StaticLazyConstraintsComponent(Component):
@staticmethod
def xy_sample(
instance: Any,
features: Features,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
) -> Optional[Tuple[Dict, Dict]]:
if "LazyStatic: Enforced" not in sample:
return None
x: Dict = {}
y: Dict = {}
if "LazyStatic: All" not in sample:
return x, y
for cid in sorted(sample["LazyStatic: All"]):
cfeatures = instance.features["Constraints"][cid]
for (cid, cfeatures) in features["Constraints"].items():
if not cfeatures["Lazy"]:
continue
category = cfeatures["Category"]
if category is None:
continue

View File

@@ -19,7 +19,7 @@ from miplearn.classifiers import Regressor
from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import MIPSolveStats, TrainingSample, LearningSolveStats
from miplearn.types import MIPSolveStats, TrainingSample, LearningSolveStats, Features
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
@@ -164,18 +164,20 @@ class ObjectiveValueComponent(Component):
@staticmethod
def xy_sample(
instance: Any,
features: Features,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
) -> Optional[Tuple[Dict, Dict]]:
if "Lower bound" not in sample:
return x, y
features = instance.features["Instance"]["User features"]
return None
f = features["Instance"]["User features"]
if "LP value" in sample and sample["LP value"] is not None:
features += [sample["LP value"]]
x["Lower bound"] = [features]
x["Upper bound"] = [features]
y["Lower bound"] = [[sample["Lower bound"]]]
y["Upper bound"] = [[sample["Upper bound"]]]
f += [sample["LP value"]]
x = {
"Lower bound": [f],
"Upper bound": [f],
}
y = {
"Lower bound": [[sample["Lower bound"]]],
"Upper bound": [[sample["Upper bound"]]],
}
return x, y

View File

@@ -211,15 +211,15 @@ class PrimalSolutionComponent(Component):
@staticmethod
def xy_sample(
instance: Any,
features: Features,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
) -> Optional[Tuple[Dict, Dict]]:
if "Solution" not in sample:
return {}, {}
return None
assert sample["Solution"] is not None
return cast(
Tuple[Dict, Dict],
PrimalSolutionComponent._extract(instance.features, sample),
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
@@ -227,7 +227,10 @@ class PrimalSolutionComponent(Component):
features: Features,
sample: TrainingSample,
) -> Dict:
return cast(Dict, PrimalSolutionComponent._extract(features, sample))
return cast(
Dict,
PrimalSolutionComponent._extract(features, sample),
)
@staticmethod
def _extract(