Add types to remaining files; activate mypy's disallow_untyped_defs

This commit is contained in:
2021-04-07 21:25:30 -05:00
parent f5606efb72
commit e9cd6d1715
21 changed files with 102 additions and 64 deletions

0
benchmark/__init__.py Normal file
View File

View File

@@ -24,6 +24,7 @@ import importlib
import logging
import os
from pathlib import Path
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
@@ -46,7 +47,7 @@ logging.getLogger("pyomo.core").setLevel(logging.ERROR)
logger = logging.getLogger("benchmark")
def train(args):
def train(args: Dict) -> None:
basepath = args["<challenge>"]
problem_name, challenge_name = args["<challenge>"].split("/")
pkg = importlib.import_module(f"miplearn.problems.{problem_name}")
@@ -76,7 +77,7 @@ def train(args):
Path(done_filename).touch(exist_ok=True)
def test_baseline(args):
def test_baseline(args: Dict) -> None:
basepath = args["<challenge>"]
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
csv_filename = f"{basepath}/benchmark_baseline.csv"
@@ -99,7 +100,7 @@ def test_baseline(args):
benchmark.write_csv(csv_filename)
def test_ml(args):
def test_ml(args: Dict) -> None:
basepath = args["<challenge>"]
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
train_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")]
@@ -133,7 +134,7 @@ def test_ml(args):
benchmark.write_csv(csv_filename)
def charts(args):
def charts(args: Dict) -> None:
basepath = args["<challenge>"]
sns.set_style("whitegrid")
sns.set_palette("Blues_r")
@@ -244,7 +245,7 @@ def charts(args):
)
if __name__ == "__main__":
def main() -> None:
args = docopt(__doc__)
if args["train"]:
train(args)
@@ -254,3 +255,7 @@ if __name__ == "__main__":
test_ml(args)
if args["charts"]:
charts(args)
if __name__ == "__main__":
main()