mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 17:38:51 -06:00
Add customizable branch priority; add more metrics to BenchmarkRunner
This commit is contained in:
@@ -28,12 +28,12 @@ def test_benchmark():
|
||||
}
|
||||
benchmark = BenchmarkRunner(test_solvers)
|
||||
benchmark.load_fit("data.bin")
|
||||
benchmark.parallel_solve(test_instances, n_jobs=2)
|
||||
assert benchmark.raw_results().values.shape == (6,6)
|
||||
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
|
||||
assert benchmark.raw_results().values.shape == (12,12)
|
||||
|
||||
benchmark.save_results("/tmp/benchmark.csv")
|
||||
assert os.path.isfile("/tmp/benchmark.csv")
|
||||
|
||||
benchmark = BenchmarkRunner(test_solvers)
|
||||
benchmark.load_results("/tmp/benchmark.csv")
|
||||
assert benchmark.raw_results().values.shape == (6,6)
|
||||
assert benchmark.raw_results().values.shape == (12,12)
|
||||
|
||||
@@ -40,4 +40,11 @@ def test_parallel_solve():
|
||||
solver = LearningSolver()
|
||||
solver.parallel_solve(instances, n_jobs=3)
|
||||
assert len(solver.x_train[0]) == 10
|
||||
assert len(solver.y_train[0]) == 10
|
||||
assert len(solver.y_train[0]) == 10
|
||||
|
||||
def test_solver_random_branch_priority():
|
||||
instance = KnapsackInstance2(weights=[23., 26., 20., 18.],
|
||||
prices=[505., 352., 458., 220.],
|
||||
capacity=67.)
|
||||
solver = LearningSolver(branch_priority=[1, 2, 3, 4])
|
||||
solver.solve(instance, tee=True)
|
||||
Reference in New Issue
Block a user