Reformat source code with Black; add pre-commit hooks and CI checks

This commit is contained in:
2020-12-05 10:59:33 -06:00
parent 3823931382
commit d99600f101
49 changed files with 1291 additions and 972 deletions

View File

@@ -9,20 +9,22 @@ from miplearn.problems.knapsack import KnapsackInstance, GurobiKnapsackInstance
def _get_instance(solver):
def _is_subclass_or_instance(solver, parentClass):
return isinstance(solver, parentClass) or (isclass(solver) and issubclass(solver, parentClass))
return isinstance(solver, parentClass) or (
isclass(solver) and issubclass(solver, parentClass)
)
if _is_subclass_or_instance(solver, BasePyomoSolver):
return KnapsackInstance(
weights=[23., 26., 20., 18.],
prices=[505., 352., 458., 220.],
capacity=67.,
weights=[23.0, 26.0, 20.0, 18.0],
prices=[505.0, 352.0, 458.0, 220.0],
capacity=67.0,
)
if _is_subclass_or_instance(solver, GurobiSolver):
return GurobiKnapsackInstance(
weights=[23., 26., 20., 18.],
prices=[505., 352., 458., 220.],
capacity=67.,
weights=[23.0, 26.0, 20.0, 18.0],
prices=[505.0, 352.0, 458.0, 220.0],
capacity=67.0,
)
assert False

View File

@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
def test_redirect_output():
import sys
original_stdout = sys.stdout
io = StringIO()
with RedirectOutput([io]):
@@ -31,36 +32,42 @@ def test_internal_solver_warm_starts():
model = instance.to_model()
solver = solver_class()
solver.set_instance(instance, model)
solver.set_warm_start({
"x": {
0: 1.0,
1: 0.0,
2: 0.0,
3: 1.0,
solver.set_warm_start(
{
"x": {
0: 1.0,
1: 0.0,
2: 0.0,
3: 1.0,
}
}
})
)
stats = solver.solve(tee=True)
assert stats["Warm start value"] == 725.0
solver.set_warm_start({
"x": {
0: 1.0,
1: 1.0,
2: 1.0,
3: 1.0,
solver.set_warm_start(
{
"x": {
0: 1.0,
1: 1.0,
2: 1.0,
3: 1.0,
}
}
})
)
stats = solver.solve(tee=True)
assert stats["Warm start value"] is None
solver.fix({
"x": {
0: 1.0,
1: 0.0,
2: 0.0,
3: 1.0,
solver.fix(
{
"x": {
0: 1.0,
1: 0.0,
2: 0.0,
3: 1.0,
}
}
})
)
stats = solver.solve(tee=True)
assert stats["Lower bound"] == 725.0
assert stats["Upper bound"] == 725.0

View File

@@ -20,11 +20,13 @@ def test_learning_solver():
for internal_solver in _get_internal_solvers():
logger.info("Solver: %s" % internal_solver)
instance = _get_instance(internal_solver)
solver = LearningSolver(time_limit=300,
gap_tolerance=1e-3,
threads=1,
solver=internal_solver,
mode=mode)
solver = LearningSolver(
time_limit=300,
gap_tolerance=1e-3,
threads=1,
solver=internal_solver,
mode=mode,
)
solver.solve(instance)
assert instance.solution["x"][0] == 1.0
@@ -74,37 +76,36 @@ def test_solve_fit_from_disk():
filenames = []
for k in range(3):
instance = _get_instance(internal_solver)
with tempfile.NamedTemporaryFile(suffix=".pkl",
delete=False) as file:
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file:
filenames += [file.name]
pickle.dump(instance, file)
# Test: solve
solver = LearningSolver(solver=internal_solver)
solver.solve(filenames[0])
with open(filenames[0], "rb") as file:
instance = pickle.load(file)
assert hasattr(instance, "solution")
# Test: parallel_solve
solver.parallel_solve(filenames)
for filename in filenames:
with open(filename, "rb") as file:
instance = pickle.load(file)
assert hasattr(instance, "solution")
# Test: solve (with specified output)
output = [f + ".out" for f in filenames]
solver.solve(filenames[0], output=output[0])
assert os.path.isfile(output[0])
# Test: parallel_solve (with specified output)
solver.parallel_solve(filenames, output=output)
for filename in output:
assert os.path.isfile(filename)
# Delete temporary files
for filename in filenames:
os.remove(filename)
for filename in output:
os.remove(filename)
os.remove(filename)