Reformat additional files

pull/3/head
Alinson S. Xavier 5 years ago
parent d99600f101
commit 718ac0da06

@ -35,7 +35,7 @@ uninstall:
$(PIP) uninstall miplearn $(PIP) uninstall miplearn
reformat: reformat:
$(PYTHON) -m black miplearn $(PYTHON) -m black .
test: test:
$(PYTEST) $(PYTEST_ARGS) $(PYTEST) $(PYTEST_ARGS)

@ -14,22 +14,26 @@ Usage:
Options: Options:
-h --help Show this screen -h --help Show this screen
""" """
from docopt import docopt import importlib
import importlib, pathlib
from miplearn import (LearningSolver, BenchmarkRunner)
from numpy import median
import pyomo.environ as pe
import pickle
import logging import logging
import pathlib
import pickle
import sys import sys
logging.basicConfig(format='%(asctime)s %(levelname).1s %(name)s: %(message)12s', from docopt import docopt
datefmt='%H:%M:%S', from numpy import median
level=logging.INFO,
stream=sys.stdout) from miplearn import LearningSolver, BenchmarkRunner
logging.getLogger('gurobipy').setLevel(logging.ERROR)
logging.getLogger('pyomo.core').setLevel(logging.ERROR) logging.basicConfig(
logging.getLogger('miplearn').setLevel(logging.INFO) format="%(asctime)s %(levelname).1s %(name)s: %(message)12s",
datefmt="%H:%M:%S",
level=logging.INFO,
stream=sys.stdout,
)
logging.getLogger("gurobipy").setLevel(logging.ERROR)
logging.getLogger("pyomo.core").setLevel(logging.ERROR)
logging.getLogger("miplearn").setLevel(logging.INFO)
logger = logging.getLogger("benchmark") logger = logging.getLogger("benchmark")
n_jobs = 10 n_jobs = 10
@ -46,28 +50,31 @@ def save(obj, filename):
logger.info("Writing %s..." % filename) logger.info("Writing %s..." % filename)
with open(filename, "wb") as file: with open(filename, "wb") as file:
pickle.dump(obj, file) pickle.dump(obj, file)
def load(filename): def load(filename):
import pickle import pickle
with open(filename, "rb") as file: with open(filename, "rb") as file:
return pickle.load(file) return pickle.load(file)
def train(): def train():
problem_name, challenge_name = args["<challenge>"].split("/") problem_name, challenge_name = args["<challenge>"].split("/")
pkg = importlib.import_module("miplearn.problems.%s" % problem_name) pkg = importlib.import_module("miplearn.problems.%s" % problem_name)
challenge = getattr(pkg, challenge_name)() challenge = getattr(pkg, challenge_name)()
train_instances = challenge.training_instances train_instances = challenge.training_instances
test_instances = challenge.test_instances test_instances = challenge.test_instances
solver = LearningSolver(time_limit=train_time_limit, solver = LearningSolver(
solver=internal_solver, time_limit=train_time_limit,
components={}) solver=internal_solver,
components={},
)
solver.parallel_solve(train_instances, n_jobs=n_jobs) solver.parallel_solve(train_instances, n_jobs=n_jobs)
save(train_instances, "%s/train_instances.bin" % basepath) save(train_instances, "%s/train_instances.bin" % basepath)
save(test_instances, "%s/test_instances.bin" % basepath) save(test_instances, "%s/test_instances.bin" % basepath)
def test_baseline(): def test_baseline():
test_instances = load("%s/test_instances.bin" % basepath) test_instances = load("%s/test_instances.bin" % basepath)
solvers = { solvers = {
@ -79,8 +86,8 @@ def test_baseline():
benchmark = BenchmarkRunner(solvers) benchmark = BenchmarkRunner(solvers)
benchmark.parallel_solve(test_instances, n_jobs=n_jobs) benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
benchmark.save_results("%s/benchmark_baseline.csv" % basepath) benchmark.save_results("%s/benchmark_baseline.csv" % basepath)
def test_ml(): def test_ml():
logger.info("Loading instances...") logger.info("Loading instances...")
train_instances = load("%s/train_instances.bin" % basepath) train_instances = load("%s/train_instances.bin" % basepath)
@ -105,17 +112,18 @@ def test_ml():
benchmark.parallel_solve(test_instances, n_jobs=n_jobs) benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
benchmark.save_results("%s/benchmark_ml.csv" % basepath) benchmark.save_results("%s/benchmark_ml.csv" % basepath)
def charts(): def charts():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
sns.set_style("whitegrid") sns.set_style("whitegrid")
sns.set_palette("Blues_r") sns.set_palette("Blues_r")
benchmark = BenchmarkRunner({}) benchmark = BenchmarkRunner({})
benchmark.load_results("%s/benchmark_ml.csv" % basepath) benchmark.load_results("%s/benchmark_ml.csv" % basepath)
results = benchmark.raw_results() results = benchmark.raw_results()
results["Gap (%)"] = results["Gap"] * 100.0 results["Gap (%)"] = results["Gap"] * 100.0
sense = results.loc[0, "Sense"] sense = results.loc[0, "Sense"]
if sense == "min": if sense == "min":
primal_column = "Relative Upper Bound" primal_column = "Relative Upper Bound"
@ -125,70 +133,71 @@ def charts():
primal_column = "Relative Lower Bound" primal_column = "Relative Lower Bound"
obj_column = "Lower Bound" obj_column = "Lower Bound"
predicted_obj_column = "Predicted LB" predicted_obj_column = "Predicted LB"
palette={ palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
"baseline": "#9b59b6", fig, (ax1, ax2, ax3, ax4) = plt.subplots(
"ml-exact": "#3498db", nrows=1,
"ml-heuristic": "#95a5a6" ncols=4,
} figsize=(12, 4),
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, gridspec_kw={"width_ratios": [2, 1, 1, 2]},
ncols=4, )
figsize=(12,4), sns.stripplot(
gridspec_kw={'width_ratios': [2, 1, 1, 2]}, x="Solver",
) y="Wallclock Time",
sns.stripplot(x="Solver", data=results,
y="Wallclock Time", ax=ax1,
data=results, jitter=0.25,
ax=ax1, palette=palette,
jitter=0.25, size=4.0,
palette=palette, )
size=4.0, sns.barplot(
); x="Solver",
sns.barplot(x="Solver", y="Wallclock Time",
y="Wallclock Time", data=results,
data=results, ax=ax1,
ax=ax1, errwidth=0.0,
errwidth=0., alpha=0.4,
alpha=0.4, palette=palette,
palette=palette, estimator=median,
estimator=median, )
); ax1.set(ylabel="Wallclock Time (s)")
ax1.set(ylabel='Wallclock Time (s)')
ax2.set_ylim(-0.5, 5.5) ax2.set_ylim(-0.5, 5.5)
sns.stripplot(x="Solver", sns.stripplot(
y="Gap (%)", x="Solver",
jitter=0.25, y="Gap (%)",
data=results[results["Solver"] != "ml-heuristic"], jitter=0.25,
ax=ax2, data=results[results["Solver"] != "ml-heuristic"],
palette=palette, ax=ax2,
size=4.0, palette=palette,
); size=4.0,
ax3.set_ylim(0.95,1.05) )
sns.stripplot(x="Solver", ax3.set_ylim(0.95, 1.05)
y=primal_column, sns.stripplot(
jitter=0.25, x="Solver",
data=results[results["Solver"] == "ml-heuristic"], y=primal_column,
ax=ax3, jitter=0.25,
palette=palette, data=results[results["Solver"] == "ml-heuristic"],
); ax=ax3,
palette=palette,
sns.scatterplot(x=obj_column, )
y=predicted_obj_column,
hue="Solver", sns.scatterplot(
data=results[results["Solver"] == "ml-exact"], x=obj_column,
ax=ax4, y=predicted_obj_column,
palette=palette, hue="Solver",
); data=results[results["Solver"] == "ml-exact"],
ax=ax4,
palette=palette,
)
xlim, ylim = ax4.get_xlim(), ax4.get_ylim() xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls='-', color="#cccccc"); ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls="-", color="#cccccc")
ax4.set_xlim(xlim) ax4.set_xlim(xlim)
ax4.set_ylim(ylim) ax4.set_ylim(ylim)
ax4.get_legend().remove() ax4.get_legend().remove()
fig.tight_layout() fig.tight_layout()
plt.savefig("%s/performance.png" % basepath, plt.savefig("%s/performance.png" % basepath, bbox_inches="tight", dpi=150)
bbox_inches='tight',
dpi=150)
if __name__ == "__main__": if __name__ == "__main__":
if args["train"]: if args["train"]:

@ -4,27 +4,27 @@ with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
setup( setup(
name='miplearn', name="miplearn",
version='0.2.0', version="0.2.0",
author='Alinson S. Xavier', author="Alinson S. Xavier",
author_email='axavier@anl.gov', author_email="axavier@anl.gov",
description="Extensible framework for Learning-Enhanced Mixed-Integer Optimization", description="Extensible framework for Learning-Enhanced Mixed-Integer Optimization",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/ANL-CEEESA/MIPLearn/", url="https://github.com/ANL-CEEESA/MIPLearn/",
packages=find_namespace_packages(), packages=find_namespace_packages(),
python_requires='>=3.6', python_requires=">=3.6",
install_requires=[ install_requires=[
'docopt', "docopt",
'matplotlib', "matplotlib",
'networkx', "networkx",
'numpy', "numpy",
'pandas', "pandas",
'p_tqdm', "p_tqdm",
'pyomo', "pyomo",
'python-markdown-math', "python-markdown-math",
'seaborn', "seaborn",
'sklearn', "sklearn",
'tqdm', "tqdm",
], ],
) )

Loading…
Cancel
Save