mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Make xy_sample receive features, not instances
This commit is contained in:
@@ -3,11 +3,11 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict
|
||||
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional
|
||||
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.types import LearningSolveStats, TrainingSample
|
||||
from miplearn.types import LearningSolveStats, TrainingSample, Features
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
@@ -133,14 +133,16 @@ class Component:
|
||||
|
||||
@staticmethod
|
||||
def xy_sample(
|
||||
instance: Any,
|
||||
training_sample: TrainingSample,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Optional[Tuple[Dict, Dict]]:
|
||||
"""
|
||||
Given a training sample, returns a pair of x and y dictionaries containing,
|
||||
respectively, the matrices of ML features and the labels for the sample.
|
||||
Given a set of features and a training sample, returns a pair of x and y
|
||||
dictionaries containing, respectively, the matrices of ML features and the
|
||||
labels for the sample. If the training sample does not include label
|
||||
information, returns None.
|
||||
"""
|
||||
return {}, {}
|
||||
return None
|
||||
|
||||
def xy_instances(
|
||||
self,
|
||||
@@ -149,8 +151,12 @@ class Component:
|
||||
x_combined: Dict = {}
|
||||
y_combined: Dict = {}
|
||||
for instance in InstanceIterator(instances):
|
||||
assert isinstance(instance, Instance)
|
||||
for sample in instance.training_data:
|
||||
x_sample, y_sample = self.xy_sample(instance, sample)
|
||||
xy = self.xy_sample(instance.features, sample)
|
||||
if xy is None:
|
||||
continue
|
||||
x_sample, y_sample = xy
|
||||
for cat in x_sample.keys():
|
||||
if cat not in x_combined:
|
||||
x_combined[cat] = []
|
||||
|
||||
Reference in New Issue
Block a user