Add after_solve_lp callback; make dict keys consistent

This commit is contained in:
2021-03-30 10:05:28 -05:00
parent 6ae052c8d0
commit 3b61a15ead
9 changed files with 115 additions and 54 deletions

View File

@@ -68,7 +68,7 @@ def test_convert_tight_infeasibility():
solver = LearningSolver(
solver=GurobiSolver,
components=[comp],
solve_lp_first=False,
solve_lp=False,
)
instance = SampleInstance()
stats = solver.solve(instance)
@@ -91,7 +91,7 @@ def test_convert_tight_suboptimality():
solver = LearningSolver(
solver=GurobiSolver,
components=[comp],
solve_lp_first=False,
solve_lp=False,
)
instance = SampleInstance()
stats = solver.solve(instance)
@@ -114,7 +114,7 @@ def test_convert_tight_optimal():
solver = LearningSolver(
solver=GurobiSolver,
components=[comp],
solve_lp_first=False,
solve_lp=False,
)
instance = SampleInstance()
stats = solver.solve(instance)

View File

@@ -93,8 +93,8 @@ def test_internal_solver():
stats = solver.solve_lp()
assert not solver.is_infeasible()
assert round(stats["Optimal value"], 3) == 1287.923
assert len(stats["Log"]) > 100
assert round(stats["LP value"], 3) == 1287.923
assert len(stats["LP log"]) > 100
solution = solver.get_solution()
assert round(solution["x"][0], 3) == 1.000
@@ -104,7 +104,7 @@ def test_internal_solver():
stats = solver.solve(tee=True)
assert not solver.is_infeasible()
assert len(stats["Log"]) > 100
assert len(stats["MIP log"]) > 100
assert stats["Lower bound"] == 1183.0
assert stats["Upper bound"] == 1183.0
assert stats["Sense"] == "max"
@@ -198,7 +198,7 @@ def test_infeasible_instance():
stats = solver.solve_lp()
assert solver.get_solution() is None
assert stats["Optimal value"] is None
assert stats["LP value"] is None
assert solver.get_value("x", 0) is None

View File

@@ -57,7 +57,7 @@ def test_solve_without_lp():
instance = _get_knapsack_instance(internal_solver)
solver = LearningSolver(
solver=internal_solver,
solve_lp_first=False,
solve_lp=False,
)
solver.solve(instance)
solver.fit([instance])

View File

@@ -29,7 +29,7 @@ def test_benchmark():
benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.results.values.shape == (12, 17)
assert benchmark.results.values.shape == (12, 18)
benchmark.write_csv("/tmp/benchmark.csv")
assert os.path.isfile("/tmp/benchmark.csv")