Replace InstanceIterator by PickleGzInstance

This commit is contained in:
2021-04-04 14:48:46 -05:00
parent b4770c6c0a
commit 08e808690e
14 changed files with 253 additions and 257 deletions

View File

@@ -13,7 +13,6 @@ from miplearn.classifiers.counting import CountingClassifier
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.components.steps.drop_redundant import DropRedundantInequalitiesStep
from miplearn.extractors import InstanceIterator
logger = logging.getLogger(__name__)
@@ -116,7 +115,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
def _x_train(instances):
x = {}
for instance in tqdm(
InstanceIterator(instances),
instances,
desc="Extract (drop:x)",
disable=len(instances) < 5,
):
@@ -139,7 +138,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
def y(self, instances):
y = {}
for instance in tqdm(
InstanceIterator(instances),
instances,
desc="Extract (rlx:conv_ineqs:y)",
disable=len(instances) < 5,
):

View File

@@ -6,14 +6,13 @@ import logging
from copy import deepcopy
import numpy as np
from tqdm import tqdm
from p_tqdm import p_umap
from tqdm import tqdm
from miplearn.classifiers.counting import CountingClassifier
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.components.lazy_static import LazyConstraint
from miplearn.extractors import InstanceIterator
logger = logging.getLogger(__name__)
@@ -131,31 +130,24 @@ class DropRedundantInequalitiesStep(Component):
def _extract(instance):
x = {}
y = {}
for instance in InstanceIterator([instance]):
for training_data in instance.training_data:
for (cid, slack) in training_data["slacks"].items():
category = instance.get_constraint_category(cid)
if category is None:
continue
if category not in x:
x[category] = []
if category not in y:
y[category] = []
if slack > self.slack_tolerance:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
x[category] += [instance.get_constraint_features(cid)]
for training_data in instance.training_data:
for (cid, slack) in training_data["slacks"].items():
category = instance.get_constraint_category(cid)
if category is None:
continue
if category not in x:
x[category] = []
if category not in y:
y[category] = []
if slack > self.slack_tolerance:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
x[category] += [instance.get_constraint_features(cid)]
return x, y
if n_jobs == 1:
results = [
_extract(i)
for i in tqdm(
instances,
desc="Extract (drop 1/3)",
)
]
results = [_extract(i) for i in tqdm(instances, desc="Extract (drop 1/3)")]
else:
results = p_umap(
_extract,