Reformat source code with Black; add pre-commit hooks and CI checks

This commit is contained in:
2020-12-05 10:59:33 -06:00
parent 3823931382
commit d99600f101
49 changed files with 1291 additions and 972 deletions

View File

@@ -19,13 +19,14 @@ class LazyConstraint:
class StaticLazyConstraintsComponent(Component):
def __init__(self,
classifier=CountingClassifier(),
threshold=0.05,
use_two_phase_gap=True,
large_gap=1e-2,
violation_tolerance=-0.5,
):
def __init__(
self,
classifier=CountingClassifier(),
threshold=0.05,
use_two_phase_gap=True,
large_gap=1e-2,
violation_tolerance=-0.5,
):
self.threshold = threshold
self.classifier_prototype = classifier
self.classifiers = {}
@@ -74,32 +75,38 @@ class StaticLazyConstraintsComponent(Component):
logger.debug("Finding violated lazy constraints...")
constraints_to_add = []
for c in self.pool:
if not solver.internal_solver.is_constraint_satisfied(c.obj,
tol=self.violation_tolerance):
if not solver.internal_solver.is_constraint_satisfied(
c.obj, tol=self.violation_tolerance
):
constraints_to_add.append(c)
for c in constraints_to_add:
self.pool.remove(c)
solver.internal_solver.add_constraint(c.obj)
instance.found_violated_lazy_constraints += [c.cid]
if len(constraints_to_add) > 0:
logger.info("%8d lazy constraints added %8d in the pool" % (len(constraints_to_add), len(self.pool)))
logger.info(
"%8d lazy constraints added %8d in the pool"
% (len(constraints_to_add), len(self.pool))
)
return True
else:
return False
def fit(self, training_instances):
training_instances = [t
for t in training_instances
if hasattr(t, "found_violated_lazy_constraints")]
training_instances = [
t
for t in training_instances
if hasattr(t, "found_violated_lazy_constraints")
]
logger.debug("Extracting x and y...")
x = self.x(training_instances)
y = self.y(training_instances)
logger.debug("Fitting...")
for category in tqdm(x.keys(),
desc="Fit (lazy)",
disable=not sys.stdout.isatty()):
for category in tqdm(
x.keys(), desc="Fit (lazy)", disable=not sys.stdout.isatty()
):
if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category].fit(x[category], y[category])
@@ -121,8 +128,10 @@ class StaticLazyConstraintsComponent(Component):
x[category] = []
constraints[category] = []
x[category] += [instance.get_constraint_features(cid)]
c = LazyConstraint(cid=cid,
obj=solver.internal_solver.extract_constraint(cid))
c = LazyConstraint(
cid=cid,
obj=solver.internal_solver.extract_constraint(cid),
)
constraints[category] += [c]
self.pool.append(c)
logger.info("%8d lazy constraints extracted" % len(self.pool))
@@ -141,7 +150,13 @@ class StaticLazyConstraintsComponent(Component):
self.pool.remove(c)
solver.internal_solver.add_constraint(c.obj)
instance.found_violated_lazy_constraints += [c.cid]
logger.info("%8d lazy constraints added %8d in the pool" % (n_added, len(self.pool)))
logger.info(
"%8d lazy constraints added %8d in the pool"
% (
n_added,
len(self.pool),
)
)
def _collect_constraints(self, train_instances):
constraints = {}