Add types to remaining files; activate mypy's disallow_untyped_defs

This commit is contained in:
2021-04-07 21:25:30 -05:00
parent f5606efb72
commit e9cd6d1715
21 changed files with 102 additions and 64 deletions

View File

@@ -1,14 +1,16 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Dict, Tuple
from unittest.mock import Mock
from miplearn.components.component import Component
from miplearn.features import Features, TrainingSample
from miplearn.instance.base import Instance
def test_xy_instance():
def _sample_xy(features, sample):
def test_xy_instance() -> None:
def _sample_xy(features: Features, sample: str) -> Tuple[Dict, Dict]:
x = {
"s1": {
"category_a": [
@@ -58,7 +60,7 @@ def test_xy_instance():
instance_2 = Mock(spec=Instance)
instance_2.training_data = ["s3"]
instance_2.features = {}
comp.sample_xy = _sample_xy
comp.sample_xy = _sample_xy # type: ignore
x_expected = {
"category_a": [
[1, 2, 3],

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, FrozenSet, Hashable
from typing import Any, FrozenSet, Hashable, List
import gurobipy as gp
import networkx as nx
@@ -39,7 +39,7 @@ class GurobiStableSetProblem(Instance):
return True
@overrides
def find_violated_user_cuts(self, model):
def find_violated_user_cuts(self, model: Any) -> List[FrozenSet]:
assert isinstance(model, gp.Model)
vals = model.cbGetNodeRel(model.getVars())
violations = []

View File

@@ -1,6 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import cast
from unittest.mock import Mock
import numpy as np
@@ -179,9 +180,9 @@ def test_predict() -> None:
}
def test_fit_xy():
def test_fit_xy() -> None:
clf = Mock(spec=Classifier)
clf.clone = lambda: Mock(spec=Classifier)
clf.clone = lambda: Mock(spec=Classifier) # type: ignore
thr = Mock(spec=Threshold)
thr.clone = lambda: Mock(spec=Threshold)
comp = PrimalSolutionComponent(classifier=clf, threshold=thr)
@@ -197,17 +198,17 @@ def test_fit_xy():
for category in ["type-a", "type-b"]:
assert category in comp.classifiers
assert category in comp.thresholds
clf = comp.classifiers[category]
clf = comp.classifiers[category] # type: ignore
clf.fit.assert_called_once()
assert_array_equal(x[category], clf.fit.call_args[0][0])
assert_array_equal(y[category], clf.fit.call_args[0][1])
thr = comp.thresholds[category]
thr = comp.thresholds[category] # type: ignore
thr.fit.assert_called_once()
assert_array_equal(x[category], thr.fit.call_args[0][1])
assert_array_equal(y[category], thr.fit.call_args[0][2])
def test_usage():
def test_usage() -> None:
solver = LearningSolver(
components=[
PrimalSolutionComponent(),