mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 09:58:51 -06:00
Make progress bar labels consistent
This commit is contained in:
@@ -104,7 +104,7 @@ class BranchPriorityComponent(Component):
|
|||||||
# Run strong branching on pending instances
|
# Run strong branching on pending instances
|
||||||
subcomponents = Parallel(n_jobs=n_jobs)(
|
subcomponents = Parallel(n_jobs=n_jobs)(
|
||||||
delayed(_process)(instance)
|
delayed(_process)(instance)
|
||||||
for instance in tqdm(self.pending_instances, desc="Branch priority")
|
for instance in tqdm(self.pending_instances, desc="Fit (branch)")
|
||||||
)
|
)
|
||||||
self.merge(subcomponents)
|
self.merge(subcomponents)
|
||||||
self.pending_instances.clear()
|
self.pending_instances.clear()
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
features = VariableFeaturesExtractor().extract(training_instances)
|
features = VariableFeaturesExtractor().extract(training_instances)
|
||||||
solutions = SolutionExtractor().extract(training_instances)
|
solutions = SolutionExtractor().extract(training_instances)
|
||||||
|
|
||||||
for category in tqdm(features.keys(), desc="Fit (Primal)"):
|
for category in tqdm(features.keys(), desc="Fit (primal)"):
|
||||||
x_train = features[category]
|
x_train = features[category]
|
||||||
y_train = solutions[category]
|
y_train = solutions[category]
|
||||||
for label in [0, 1]:
|
for label in [0, 1]:
|
||||||
@@ -116,7 +116,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
def evaluate(self, instances):
|
def evaluate(self, instances):
|
||||||
ev = {"Fix zero": {},
|
ev = {"Fix zero": {},
|
||||||
"Fix one": {}}
|
"Fix one": {}}
|
||||||
for instance_idx in tqdm(range(len(instances))):
|
for instance_idx in tqdm(range(len(instances)),
|
||||||
|
desc="Evaluate (primal)"):
|
||||||
instance = instances[instance_idx]
|
instance = instances[instance_idx]
|
||||||
solution_actual = instance.solution
|
solution_actual = instance.solution
|
||||||
solution_pred = self.predict(instance)
|
solution_pred = self.predict(instance)
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class VariableFeaturesExtractor(Extractor):
|
|||||||
def extract(self, instances):
|
def extract(self, instances):
|
||||||
result = {}
|
result = {}
|
||||||
for instance in tqdm(instances,
|
for instance in tqdm(instances,
|
||||||
desc="Extract var features",
|
desc="Extract (vars)",
|
||||||
disable=len(instances) < 5):
|
disable=len(instances) < 5):
|
||||||
instance_features = instance.get_instance_features()
|
instance_features = instance.get_instance_features()
|
||||||
var_split = self.split_variables(instance)
|
var_split = self.split_variables(instance)
|
||||||
@@ -60,7 +60,7 @@ class SolutionExtractor(Extractor):
|
|||||||
def extract(self, instances):
|
def extract(self, instances):
|
||||||
result = {}
|
result = {}
|
||||||
for instance in tqdm(instances,
|
for instance in tqdm(instances,
|
||||||
desc="Extract solution",
|
desc="Extract (solution)",
|
||||||
disable=len(instances) < 5):
|
disable=len(instances) < 5):
|
||||||
var_split = self.split_variables(instance)
|
var_split = self.split_variables(instance)
|
||||||
for (category, var_index_pairs) in var_split.items():
|
for (category, var_index_pairs) in var_split.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user