From 3f4336f902005afc8069804d33232f4016de81bc Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Fri, 9 Apr 2021 08:18:54 -0500 Subject: [PATCH] Always remove .mypy_cache; fix more mypy tests --- Makefile | 1 + benchmark/benchmark.py | 17 ++++++++++++----- tests/test_benchmark.py | 10 +++++++--- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 8213125..aeb5739 100644 --- a/Makefile +++ b/Makefile @@ -42,6 +42,7 @@ reformat: $(PYTHON) -m black . test: + rm -rf .mypy_cache $(MYPY) -p miplearn $(MYPY) -p tests $(MYPY) -p benchmark diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index 3e489dc..47c2a14 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -24,7 +24,7 @@ import importlib import logging import os from pathlib import Path -from typing import Dict +from typing import Dict, List import matplotlib.pyplot as plt import pandas as pd @@ -39,6 +39,7 @@ from miplearn import ( setup_logger, PickleGzInstance, write_pickle_gz_multiple, + Instance, ) setup_logger() @@ -59,7 +60,7 @@ def train(args: Dict) -> None: done_filename = f"{basepath}/train/done" if not os.path.isfile(done_filename): - train_instances = [ + train_instances: List[Instance] = [ PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz") ] solver = LearningSolver( @@ -79,7 +80,9 @@ def train(args: Dict) -> None: def test_baseline(args: Dict) -> None: basepath = args[""] - test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")] + test_instances: List[Instance] = [ + PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz") + ] csv_filename = f"{basepath}/benchmark_baseline.csv" if not os.path.isfile(csv_filename): solvers = { @@ -102,8 +105,12 @@ def test_baseline(args: Dict) -> None: def test_ml(args: Dict) -> None: basepath = args[""] - test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")] - train_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")] + test_instances: List[Instance] = [ + PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz") + ] + train_instances: List[Instance] = [ + PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz") + ] csv_filename = f"{basepath}/benchmark_ml.csv" if not os.path.isfile(csv_filename): solvers = { diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 15ea3a6..4d5303e 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -20,7 +20,7 @@ def test_benchmark() -> None: # Solve training instances training_solver = LearningSolver() - training_solver.parallel_solve(train_instances, n_jobs=n_jobs) + training_solver.parallel_solve(train_instances, n_jobs=n_jobs) # type: ignore # Benchmark test_solvers = { @@ -28,8 +28,12 @@ def test_benchmark() -> None: "Strategy B": LearningSolver(), } benchmark = BenchmarkRunner(test_solvers) - benchmark.fit(train_instances) - benchmark.parallel_solve(test_instances, n_jobs=n_jobs, n_trials=2) + benchmark.fit(train_instances) # type: ignore + benchmark.parallel_solve( + test_instances, # type: ignore + n_jobs=n_jobs, + n_trials=2, + ) assert benchmark.results.values.shape == (12, 20) benchmark.write_csv("/tmp/benchmark.csv")