mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Replace InstanceIterator by PickleGzInstance
This commit is contained in:
@@ -2,10 +2,10 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional, Hashable
|
||||
from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.types import LearningSolveStats, TrainingSample, Features
|
||||
|
||||
@@ -120,11 +120,11 @@ class Component:
|
||||
|
||||
def xy_instances(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
instances: List[Instance],
|
||||
) -> Tuple[Dict, Dict]:
|
||||
x_combined: Dict = {}
|
||||
y_combined: Dict = {}
|
||||
for instance in InstanceIterator(instances):
|
||||
for instance in instances:
|
||||
assert isinstance(instance, Instance)
|
||||
for sample in instance.training_data:
|
||||
xy = self.sample_xy(instance.features, sample)
|
||||
@@ -141,7 +141,7 @@ class Component:
|
||||
|
||||
def fit(
|
||||
self,
|
||||
training_instances: Union[List[str], List[Instance]],
|
||||
training_instances: List[Instance],
|
||||
) -> None:
|
||||
x, y = self.xy_instances(training_instances)
|
||||
for cat in x.keys():
|
||||
@@ -198,9 +198,9 @@ class Component:
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def evaluate(self, instances: Union[List[str], List[Instance]]) -> List:
|
||||
def evaluate(self, instances: List[Instance]) -> List:
|
||||
ev = []
|
||||
for instance in InstanceIterator(instances):
|
||||
for instance in instances:
|
||||
for sample in instance.training_data:
|
||||
ev += [self.sample_evaluate(instance.features, sample)]
|
||||
return ev
|
||||
|
||||
Reference in New Issue
Block a user