Always remove .mypy_cache; fix more mypy tests

This commit is contained in:
2021-04-09 08:18:54 -05:00
parent 32b6a8f3fa
commit 3f4336f902
3 changed files with 20 additions and 8 deletions

View File

@@ -42,6 +42,7 @@ reformat:
$(PYTHON) -m black . $(PYTHON) -m black .
test: test:
rm -rf .mypy_cache
$(MYPY) -p miplearn $(MYPY) -p miplearn
$(MYPY) -p tests $(MYPY) -p tests
$(MYPY) -p benchmark $(MYPY) -p benchmark

View File

@@ -24,7 +24,7 @@ import importlib
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, List
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
@@ -39,6 +39,7 @@ from miplearn import (
setup_logger, setup_logger,
PickleGzInstance, PickleGzInstance,
write_pickle_gz_multiple, write_pickle_gz_multiple,
Instance,
) )
setup_logger() setup_logger()
@@ -59,7 +60,7 @@ def train(args: Dict) -> None:
done_filename = f"{basepath}/train/done" done_filename = f"{basepath}/train/done"
if not os.path.isfile(done_filename): if not os.path.isfile(done_filename):
train_instances = [ train_instances: List[Instance] = [
PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz") PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")
] ]
solver = LearningSolver( solver = LearningSolver(
@@ -79,7 +80,9 @@ def train(args: Dict) -> None:
def test_baseline(args: Dict) -> None: def test_baseline(args: Dict) -> None:
basepath = args["<challenge>"] basepath = args["<challenge>"]
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")] test_instances: List[Instance] = [
PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")
]
csv_filename = f"{basepath}/benchmark_baseline.csv" csv_filename = f"{basepath}/benchmark_baseline.csv"
if not os.path.isfile(csv_filename): if not os.path.isfile(csv_filename):
solvers = { solvers = {
@@ -102,8 +105,12 @@ def test_baseline(args: Dict) -> None:
def test_ml(args: Dict) -> None: def test_ml(args: Dict) -> None:
basepath = args["<challenge>"] basepath = args["<challenge>"]
test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")] test_instances: List[Instance] = [
train_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")] PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")
]
train_instances: List[Instance] = [
PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")
]
csv_filename = f"{basepath}/benchmark_ml.csv" csv_filename = f"{basepath}/benchmark_ml.csv"
if not os.path.isfile(csv_filename): if not os.path.isfile(csv_filename):
solvers = { solvers = {

View File

@@ -20,7 +20,7 @@ def test_benchmark() -> None:
# Solve training instances # Solve training instances
training_solver = LearningSolver() training_solver = LearningSolver()
training_solver.parallel_solve(train_instances, n_jobs=n_jobs) training_solver.parallel_solve(train_instances, n_jobs=n_jobs) # type: ignore
# Benchmark # Benchmark
test_solvers = { test_solvers = {
@@ -28,8 +28,12 @@ def test_benchmark() -> None:
"Strategy B": LearningSolver(), "Strategy B": LearningSolver(),
} }
benchmark = BenchmarkRunner(test_solvers) benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances) benchmark.fit(train_instances) # type: ignore
benchmark.parallel_solve(test_instances, n_jobs=n_jobs, n_trials=2) benchmark.parallel_solve(
test_instances, # type: ignore
n_jobs=n_jobs,
n_trials=2,
)
assert benchmark.results.values.shape == (12, 20) assert benchmark.results.values.shape == (12, 20)
benchmark.write_csv("/tmp/benchmark.csv") benchmark.write_csv("/tmp/benchmark.csv")