mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Replace InstanceIterator by PickleGzInstance
This commit is contained in:
@@ -13,7 +13,6 @@ from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.components.steps.drop_redundant import DropRedundantInequalitiesStep
|
||||
from miplearn.extractors import InstanceIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -116,7 +115,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
|
||||
def _x_train(instances):
|
||||
x = {}
|
||||
for instance in tqdm(
|
||||
InstanceIterator(instances),
|
||||
instances,
|
||||
desc="Extract (drop:x)",
|
||||
disable=len(instances) < 5,
|
||||
):
|
||||
@@ -139,7 +138,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
|
||||
def y(self, instances):
|
||||
y = {}
|
||||
for instance in tqdm(
|
||||
InstanceIterator(instances),
|
||||
instances,
|
||||
desc="Extract (rlx:conv_ineqs:y)",
|
||||
disable=len(instances) < 5,
|
||||
):
|
||||
|
||||
@@ -6,14 +6,13 @@ import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from p_tqdm import p_umap
|
||||
from tqdm import tqdm
|
||||
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.components.lazy_static import LazyConstraint
|
||||
from miplearn.extractors import InstanceIterator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -131,31 +130,24 @@ class DropRedundantInequalitiesStep(Component):
|
||||
def _extract(instance):
|
||||
x = {}
|
||||
y = {}
|
||||
for instance in InstanceIterator([instance]):
|
||||
for training_data in instance.training_data:
|
||||
for (cid, slack) in training_data["slacks"].items():
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
x[category] = []
|
||||
if category not in y:
|
||||
y[category] = []
|
||||
if slack > self.slack_tolerance:
|
||||
y[category] += [[False, True]]
|
||||
else:
|
||||
y[category] += [[True, False]]
|
||||
x[category] += [instance.get_constraint_features(cid)]
|
||||
for training_data in instance.training_data:
|
||||
for (cid, slack) in training_data["slacks"].items():
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
x[category] = []
|
||||
if category not in y:
|
||||
y[category] = []
|
||||
if slack > self.slack_tolerance:
|
||||
y[category] += [[False, True]]
|
||||
else:
|
||||
y[category] += [[True, False]]
|
||||
x[category] += [instance.get_constraint_features(cid)]
|
||||
return x, y
|
||||
|
||||
if n_jobs == 1:
|
||||
results = [
|
||||
_extract(i)
|
||||
for i in tqdm(
|
||||
instances,
|
||||
desc="Extract (drop 1/3)",
|
||||
)
|
||||
]
|
||||
results = [_extract(i) for i in tqdm(instances, desc="Extract (drop 1/3)")]
|
||||
else:
|
||||
results = p_umap(
|
||||
_extract,
|
||||
|
||||
Reference in New Issue
Block a user