Always remove .mypy_cache; fix more mypy tests

master
Alinson S. Xavier 5 years ago
parent 32b6a8f3fa
commit 3f4336f902

@@ -42,6 +42,7 @@ reformat:
 	$(PYTHON) -m black .
 test:
+	rm -rf .mypy_cache
 	$(MYPY) -p miplearn
 	$(MYPY) -p tests
 	$(MYPY) -p benchmark

@@ -24,7 +24,7 @@ import importlib
 import logging
 import os
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -39,6 +39,7 @@ from miplearn import (
     setup_logger,
     PickleGzInstance,
     write_pickle_gz_multiple,
+    Instance,
 )
 setup_logger()
@@ -59,7 +60,7 @@ def train(args: Dict) -> None:
     done_filename = f"{basepath}/train/done"
     if not os.path.isfile(done_filename):
-        train_instances = [
+        train_instances: List[Instance] = [
             PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")
         ]
         solver = LearningSolver(
@@ -79,7 +80,9 @@ def train(args: Dict) -> None:
 def test_baseline(args: Dict) -> None:
     basepath = args["<challenge>"]
-    test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
+    test_instances: List[Instance] = [
+        PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")
+    ]
     csv_filename = f"{basepath}/benchmark_baseline.csv"
     if not os.path.isfile(csv_filename):
         solvers = {
@@ -102,8 +105,12 @@ def test_baseline(args: Dict) -> None:
 def test_ml(args: Dict) -> None:
     basepath = args["<challenge>"]
-    test_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")]
-    train_instances = [PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")]
+    test_instances: List[Instance] = [
+        PickleGzInstance(f) for f in glob.glob(f"{basepath}/test/*.gz")
+    ]
+    train_instances: List[Instance] = [
+        PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")
+    ]
     csv_filename = f"{basepath}/benchmark_ml.csv"
     if not os.path.isfile(csv_filename):
         solvers = {
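
The List[Instance] annotations above follow a standard mypy idiom: without them, mypy infers the narrower type List[PickleGzInstance] for each list, and since List is invariant that type is rejected wherever List[Instance] is expected (presumably the declared parameter type of the solver methods these lists are passed to). A minimal sketch of the idiom, using hypothetical stand-in classes rather than the real miplearn API:

from typing import List

class Instance:  # stand-in for miplearn's Instance, illustration only
    ...

class PickleGzInstance(Instance):  # stand-in, real constructor may differ
    def __init__(self, filename: str) -> None:
        self.filename = filename

def parallel_solve(instances: List[Instance]) -> None:
    ...

# mypy infers List[PickleGzInstance]; List is invariant, so passing it
# where List[Instance] is expected is an incompatible-argument error.
inferred = [PickleGzInstance("a.gz"), PickleGzInstance("b.gz")]

# Annotating the variable up front, as this commit does, declares the
# list as List[Instance] and the call type-checks cleanly.
annotated: List[Instance] = [PickleGzInstance("a.gz"), PickleGzInstance("b.gz")]
parallel_solve(annotated)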

@@ -20,7 +20,7 @@ def test_benchmark() -> None:
     # Solve training instances
     training_solver = LearningSolver()
-    training_solver.parallel_solve(train_instances, n_jobs=n_jobs)
+    training_solver.parallel_solve(train_instances, n_jobs=n_jobs)  # type: ignore
     # Benchmark
     test_solvers = {
@@ -28,8 +28,12 @@ def test_benchmark() -> None:
         "Strategy B": LearningSolver(),
     }
     benchmark = BenchmarkRunner(test_solvers)
-    benchmark.fit(train_instances)
-    benchmark.parallel_solve(test_instances, n_jobs=n_jobs, n_trials=2)
+    benchmark.fit(train_instances)  # type: ignore
+    benchmark.parallel_solve(
+        test_instances,  # type: ignore
+        n_jobs=n_jobs,
+        n_trials=2,
+    )
     assert benchmark.results.values.shape == (12, 20)
     benchmark.write_csv("/tmp/benchmark.csv")
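
In the test module the commit takes the opposite approach: instead of re-annotating the fixtures, the offending calls are silenced with "# type: ignore", and parallel_solve is reformatted into a multi-line call so the comment can sit on the argument that triggers the error. A short sketch of the same pattern with stand-in types; the arg-type error code in the last line is an assumption about what mypy reports here, shown only to illustrate the narrower, error-code-scoped form of the comment:

from typing import List

class Instance:  # stand-in, illustration only
    ...

class PickleGzInstance(Instance):  # stand-in
    ...

def parallel_solve(instances: List[Instance], n_jobs: int = 1) -> None:
    ...

instances = [PickleGzInstance()]  # inferred as List[PickleGzInstance]

# Blanket suppression, as used in the commit: every mypy error reported
# on this line is ignored.
parallel_solve(instances, n_jobs=4)  # type: ignore

# Scoped suppression: only arg-type errors on this line are ignored, so
# unrelated mistakes on the same line still get reported.
parallel_solve(instances, n_jobs=4)  # type: ignore[arg-type]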
