Reformat additional files

pull/3/head
Alinson S. Xavier 5 years ago
parent d99600f101
commit 718ac0da06

@ -35,7 +35,7 @@ uninstall:
$(PIP) uninstall miplearn $(PIP) uninstall miplearn
reformat: reformat:
$(PYTHON) -m black miplearn $(PYTHON) -m black .
test: test:
$(PYTEST) $(PYTEST_ARGS) $(PYTEST) $(PYTEST_ARGS)

@ -14,22 +14,26 @@ Usage:
Options: Options:
-h --help Show this screen -h --help Show this screen
""" """
from docopt import docopt import importlib
import importlib, pathlib
from miplearn import (LearningSolver, BenchmarkRunner)
from numpy import median
import pyomo.environ as pe
import pickle
import logging import logging
import pathlib
import pickle
import sys import sys
logging.basicConfig(format='%(asctime)s %(levelname).1s %(name)s: %(message)12s', from docopt import docopt
datefmt='%H:%M:%S', from numpy import median
from miplearn import LearningSolver, BenchmarkRunner
logging.basicConfig(
format="%(asctime)s %(levelname).1s %(name)s: %(message)12s",
datefmt="%H:%M:%S",
level=logging.INFO, level=logging.INFO,
stream=sys.stdout) stream=sys.stdout,
logging.getLogger('gurobipy').setLevel(logging.ERROR) )
logging.getLogger('pyomo.core').setLevel(logging.ERROR) logging.getLogger("gurobipy").setLevel(logging.ERROR)
logging.getLogger('miplearn').setLevel(logging.INFO) logging.getLogger("pyomo.core").setLevel(logging.ERROR)
logging.getLogger("miplearn").setLevel(logging.INFO)
logger = logging.getLogger("benchmark") logger = logging.getLogger("benchmark")
n_jobs = 10 n_jobs = 10
@ -50,6 +54,7 @@ def save(obj, filename):
def load(filename): def load(filename):
import pickle import pickle
with open(filename, "rb") as file: with open(filename, "rb") as file:
return pickle.load(file) return pickle.load(file)
@ -60,9 +65,11 @@ def train():
challenge = getattr(pkg, challenge_name)() challenge = getattr(pkg, challenge_name)()
train_instances = challenge.training_instances train_instances = challenge.training_instances
test_instances = challenge.test_instances test_instances = challenge.test_instances
solver = LearningSolver(time_limit=train_time_limit, solver = LearningSolver(
time_limit=train_time_limit,
solver=internal_solver, solver=internal_solver,
components={}) components={},
)
solver.parallel_solve(train_instances, n_jobs=n_jobs) solver.parallel_solve(train_instances, n_jobs=n_jobs)
save(train_instances, "%s/train_instances.bin" % basepath) save(train_instances, "%s/train_instances.bin" % basepath)
save(test_instances, "%s/test_instances.bin" % basepath) save(test_instances, "%s/test_instances.bin" % basepath)
@ -109,6 +116,7 @@ def test_ml():
def charts(): def charts():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
sns.set_style("whitegrid") sns.set_style("whitegrid")
sns.set_palette("Blues_r") sns.set_palette("Blues_r")
benchmark = BenchmarkRunner({}) benchmark = BenchmarkRunner({})
@ -126,69 +134,70 @@ def charts():
obj_column = "Lower Bound" obj_column = "Lower Bound"
predicted_obj_column = "Predicted LB" predicted_obj_column = "Predicted LB"
palette={ palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
"baseline": "#9b59b6", fig, (ax1, ax2, ax3, ax4) = plt.subplots(
"ml-exact": "#3498db", nrows=1,
"ml-heuristic": "#95a5a6"
}
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1,
ncols=4, ncols=4,
figsize=(12,4), figsize=(12, 4),
gridspec_kw={'width_ratios': [2, 1, 1, 2]}, gridspec_kw={"width_ratios": [2, 1, 1, 2]},
) )
sns.stripplot(x="Solver", sns.stripplot(
x="Solver",
y="Wallclock Time", y="Wallclock Time",
data=results, data=results,
ax=ax1, ax=ax1,
jitter=0.25, jitter=0.25,
palette=palette, palette=palette,
size=4.0, size=4.0,
); )
sns.barplot(x="Solver", sns.barplot(
x="Solver",
y="Wallclock Time", y="Wallclock Time",
data=results, data=results,
ax=ax1, ax=ax1,
errwidth=0., errwidth=0.0,
alpha=0.4, alpha=0.4,
palette=palette, palette=palette,
estimator=median, estimator=median,
); )
ax1.set(ylabel='Wallclock Time (s)') ax1.set(ylabel="Wallclock Time (s)")
ax2.set_ylim(-0.5, 5.5) ax2.set_ylim(-0.5, 5.5)
sns.stripplot(x="Solver", sns.stripplot(
x="Solver",
y="Gap (%)", y="Gap (%)",
jitter=0.25, jitter=0.25,
data=results[results["Solver"] != "ml-heuristic"], data=results[results["Solver"] != "ml-heuristic"],
ax=ax2, ax=ax2,
palette=palette, palette=palette,
size=4.0, size=4.0,
); )
ax3.set_ylim(0.95,1.05) ax3.set_ylim(0.95, 1.05)
sns.stripplot(x="Solver", sns.stripplot(
x="Solver",
y=primal_column, y=primal_column,
jitter=0.25, jitter=0.25,
data=results[results["Solver"] == "ml-heuristic"], data=results[results["Solver"] == "ml-heuristic"],
ax=ax3, ax=ax3,
palette=palette, palette=palette,
); )
sns.scatterplot(x=obj_column, sns.scatterplot(
x=obj_column,
y=predicted_obj_column, y=predicted_obj_column,
hue="Solver", hue="Solver",
data=results[results["Solver"] == "ml-exact"], data=results[results["Solver"] == "ml-exact"],
ax=ax4, ax=ax4,
palette=palette, palette=palette,
); )
xlim, ylim = ax4.get_xlim(), ax4.get_ylim() xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls='-', color="#cccccc"); ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls="-", color="#cccccc")
ax4.set_xlim(xlim) ax4.set_xlim(xlim)
ax4.set_ylim(ylim) ax4.set_ylim(ylim)
ax4.get_legend().remove() ax4.get_legend().remove()
fig.tight_layout() fig.tight_layout()
plt.savefig("%s/performance.png" % basepath, plt.savefig("%s/performance.png" % basepath, bbox_inches="tight", dpi=150)
bbox_inches='tight',
dpi=150)
if __name__ == "__main__": if __name__ == "__main__":
if args["train"]: if args["train"]:

@ -4,27 +4,27 @@ with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
setup( setup(
name='miplearn', name="miplearn",
version='0.2.0', version="0.2.0",
author='Alinson S. Xavier', author="Alinson S. Xavier",
author_email='axavier@anl.gov', author_email="axavier@anl.gov",
description="Extensible framework for Learning-Enhanced Mixed-Integer Optimization", description="Extensible framework for Learning-Enhanced Mixed-Integer Optimization",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/ANL-CEEESA/MIPLearn/", url="https://github.com/ANL-CEEESA/MIPLearn/",
packages=find_namespace_packages(), packages=find_namespace_packages(),
python_requires='>=3.6', python_requires=">=3.6",
install_requires=[ install_requires=[
'docopt', "docopt",
'matplotlib', "matplotlib",
'networkx', "networkx",
'numpy', "numpy",
'pandas', "pandas",
'p_tqdm', "p_tqdm",
'pyomo', "pyomo",
'python-markdown-math', "python-markdown-math",
'seaborn', "seaborn",
'sklearn', "sklearn",
'tqdm', "tqdm",
], ],
) )

Loading…
Cancel
Save