Train without loading all instances into memory

This commit is contained in:
2020-12-04 09:37:41 -06:00
parent f03cc15b75
commit 388b10c63c
5 changed files with 66 additions and 25 deletions

View File

@@ -53,7 +53,6 @@ class PrimalSolutionComponent(Component):
for category in tqdm(features.keys(),
desc="Fit (primal)",
disable=not sys.stdout.isatty(),
):
x_train = features[category]
for label in [0, 1]:
@@ -110,7 +109,6 @@ class PrimalSolutionComponent(Component):
"Fix one": {}}
for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)",
disable=not sys.stdout.isatty(),
):
instance = instances[instance_idx]
solution_actual = instance.solution

View File

@@ -8,10 +8,11 @@ from copy import deepcopy
import numpy as np
from miplearn.components import classifier_evaluation_dict
from tqdm import tqdm
from tqdm.auto import tqdm
from miplearn import Component
from miplearn.classifiers.counting import CountingClassifier
from miplearn.extractors import InstanceIterator
logger = logging.getLogger(__name__)
@@ -60,16 +61,12 @@ class RelaxationComponent(Component):
instance.slacks = solver.internal_solver.get_constraint_slacks()
def fit(self, training_instances):
training_instances = [instance
for instance in training_instances
if hasattr(instance, "slacks")]
logger.debug("Extracting x and y...")
x = self.x(training_instances)
y = self.y(training_instances)
logger.debug("Fitting...")
for category in tqdm(x.keys(),
desc="Fit (relaxation)",
disable=not sys.stdout.isatty()):
desc="Fit (relaxation)"):
if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category].fit(x[category], y[category])
@@ -80,7 +77,9 @@ class RelaxationComponent(Component):
return_constraints=False):
x = {}
constraints = {}
for instance in instances:
for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:x)",
disable=len(instances) < 5):
if constraint_ids is not None:
cids = constraint_ids
else:
@@ -101,7 +100,9 @@ class RelaxationComponent(Component):
def y(self, instances):
y = {}
for instance in instances:
for instance in tqdm(InstanceIterator(instances),
desc="Extract (relaxation:y)",
disable=len(instances) < 5):
for (cid, slack) in instance.slacks.items():
category = instance.get_constraint_category(cid)
if category is None:
@@ -120,7 +121,7 @@ class RelaxationComponent(Component):
if category not in self.classifiers:
continue
y[category] = []
#x_cat = np.array(x_cat)
x_cat = np.array(x_cat)
proba = self.classifiers[category].predict_proba(x_cat)
for i in range(len(proba)):
if proba[i][1] >= self.threshold: