Update benchmark scripts

pull/1/head
Alinson S. Xavier 6 years ago
parent 9e4004ae8a
commit 15ee7f51ee

@ -30,6 +30,7 @@ logging.getLogger('pyomo.core').setLevel(logging.ERROR)
n_jobs = 10 n_jobs = 10
time_limit = 300 time_limit = 300
internal_solver = "gurobi"
args = docopt(__doc__) args = docopt(__doc__)
basepath = args["<challenge>"] basepath = args["<challenge>"]
@ -52,9 +53,11 @@ def train():
problem_name, challenge_name = args["<challenge>"].split("/") problem_name, challenge_name = args["<challenge>"].split("/")
pkg = importlib.import_module("miplearn.problems.%s" % problem_name) pkg = importlib.import_module("miplearn.problems.%s" % problem_name)
challenge = getattr(pkg, challenge_name)() challenge = getattr(pkg, challenge_name)()
train_instances = challenge.training_instances[:10] train_instances = challenge.training_instances
test_instances = challenge.test_instances[:10] test_instances = challenge.test_instances
solver = LearningSolver(time_limit=time_limit, components={}) solver = LearningSolver(time_limit=time_limit,
solver=internal_solver,
components={})
solver.parallel_solve(train_instances, n_jobs=n_jobs) solver.parallel_solve(train_instances, n_jobs=n_jobs)
solver.fit(n_jobs=n_jobs) solver.fit(n_jobs=n_jobs)
save(train_instances, "%s/train_instances.bin" % basepath) save(train_instances, "%s/train_instances.bin" % basepath)
@ -157,4 +160,4 @@ if __name__ == "__main__":
#if args["test-ml"]: #if args["test-ml"]:
# test_ml() # test_ml()
#if args["charts"]: #if args["charts"]:
# charts() # charts()

Loading…
Cancel
Save