Reformat source code with Black; add pre-commit hooks and CI checks

This commit is contained in:
2020-12-05 10:59:33 -06:00
parent 3823931382
commit d99600f101
49 changed files with 1291 additions and 972 deletions

View File

@@ -18,10 +18,12 @@ class DynamicLazyConstraintsComponent(Component):
"""
A component that predicts which lazy constraints to enforce.
"""
def __init__(self,
classifier=CountingClassifier(),
threshold=0.05):
def __init__(
self,
classifier=CountingClassifier(),
threshold=0.05,
):
self.violations = set()
self.count = {}
self.n_samples = 0
@@ -52,7 +54,7 @@ class DynamicLazyConstraintsComponent(Component):
def after_solve(self, solver, instance, model, results):
pass
def fit(self, training_instances):
logger.debug("Fitting...")
features = InstanceFeaturesExtractor().extract(training_instances)
@@ -68,10 +70,11 @@ class DynamicLazyConstraintsComponent(Component):
violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx]
for (v, classifier) in tqdm(self.classifiers.items(),
desc="Fit (lazy)",
disable=not sys.stdout.isatty(),
):
for (v, classifier) in tqdm(
self.classifiers.items(),
desc="Fit (lazy)",
disable=not sys.stdout.isatty(),
):
logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances))
label[violation_to_instance_idx[v]] = 1.0
@@ -91,10 +94,11 @@ class DynamicLazyConstraintsComponent(Component):
all_violations = set()
for instance in instances:
all_violations |= set(instance.found_violated_lazy_constraints)
for idx in tqdm(range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
for idx in tqdm(
range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
instance = instances[idx]
condition_positive = set(instance.found_violated_lazy_constraints)
condition_negative = all_violations - condition_positive