mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Add types to remaining files; activate mypy's disallow_untyped_defs
This commit is contained in:
0
benchmark/__init__.py
Normal file
0
benchmark/__init__.py
Normal file
@@ -24,6 +24,7 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
@@ -46,7 +47,7 @@ logging.getLogger("pyomo.core").setLevel(logging.ERROR)
|
||||
logger = logging.getLogger("benchmark")
|
||||
|
||||
|
||||
def train(args):
|
||||
def train(args: Dict) -> None:
|
||||
basepath = args["<challenge>"]
|
||||
problem_name, challenge_name = args["<challenge>"].split("/")
|
||||
pkg = importlib.import_module(f"miplearn.problems.{problem_name}")
|
||||
@@ -76,7 +77,7 @@ def train(args):
|
||||
Path(done_filename).touch(exist_ok=True)
|
||||
|
||||
|
||||
def test_baseline(args):
|
||||
def test_baseline(args: Dict) -> None:
|
||||
basepath = args["<challenge>"]
|
||||
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
|
||||
csv_filename = f"{basepath}/benchmark_baseline.csv"
|
||||
@@ -99,7 +100,7 @@ def test_baseline(args):
|
||||
benchmark.write_csv(csv_filename)
|
||||
|
||||
|
||||
def test_ml(args):
|
||||
def test_ml(args: Dict) -> None:
|
||||
basepath = args["<challenge>"]
|
||||
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
|
||||
train_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")]
|
||||
@@ -133,7 +134,7 @@ def test_ml(args):
|
||||
benchmark.write_csv(csv_filename)
|
||||
|
||||
|
||||
def charts(args):
|
||||
def charts(args: Dict) -> None:
|
||||
basepath = args["<challenge>"]
|
||||
sns.set_style("whitegrid")
|
||||
sns.set_palette("Blues_r")
|
||||
@@ -244,7 +245,7 @@ def charts(args):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main() -> None:
|
||||
args = docopt(__doc__)
|
||||
if args["train"]:
|
||||
train(args)
|
||||
@@ -254,3 +255,7 @@ if __name__ == "__main__":
|
||||
test_ml(args)
|
||||
if args["charts"]:
|
||||
charts(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user