mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Add types to remaining files; activate mypy's disallow_untyped_defs
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import Dict, Tuple
|
||||
from unittest.mock import Mock
|
||||
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import Features, TrainingSample
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
|
||||
def test_xy_instance():
|
||||
def _sample_xy(features, sample):
|
||||
def test_xy_instance() -> None:
|
||||
def _sample_xy(features: Features, sample: str) -> Tuple[Dict, Dict]:
|
||||
x = {
|
||||
"s1": {
|
||||
"category_a": [
|
||||
@@ -58,7 +60,7 @@ def test_xy_instance():
|
||||
instance_2 = Mock(spec=Instance)
|
||||
instance_2.training_data = ["s3"]
|
||||
instance_2.features = {}
|
||||
comp.sample_xy = _sample_xy
|
||||
comp.sample_xy = _sample_xy # type: ignore
|
||||
x_expected = {
|
||||
"category_a": [
|
||||
[1, 2, 3],
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Any, FrozenSet, Hashable
|
||||
from typing import Any, FrozenSet, Hashable, List
|
||||
|
||||
import gurobipy as gp
|
||||
import networkx as nx
|
||||
@@ -39,7 +39,7 @@ class GurobiStableSetProblem(Instance):
|
||||
return True
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model):
|
||||
def find_violated_user_cuts(self, model: Any) -> List[FrozenSet]:
|
||||
assert isinstance(model, gp.Model)
|
||||
vals = model.cbGetNodeRel(model.getVars())
|
||||
violations = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
@@ -179,9 +180,9 @@ def test_predict() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_fit_xy():
|
||||
def test_fit_xy() -> None:
|
||||
clf = Mock(spec=Classifier)
|
||||
clf.clone = lambda: Mock(spec=Classifier)
|
||||
clf.clone = lambda: Mock(spec=Classifier) # type: ignore
|
||||
thr = Mock(spec=Threshold)
|
||||
thr.clone = lambda: Mock(spec=Threshold)
|
||||
comp = PrimalSolutionComponent(classifier=clf, threshold=thr)
|
||||
@@ -197,17 +198,17 @@ def test_fit_xy():
|
||||
for category in ["type-a", "type-b"]:
|
||||
assert category in comp.classifiers
|
||||
assert category in comp.thresholds
|
||||
clf = comp.classifiers[category]
|
||||
clf = comp.classifiers[category] # type: ignore
|
||||
clf.fit.assert_called_once()
|
||||
assert_array_equal(x[category], clf.fit.call_args[0][0])
|
||||
assert_array_equal(y[category], clf.fit.call_args[0][1])
|
||||
thr = comp.thresholds[category]
|
||||
thr = comp.thresholds[category] # type: ignore
|
||||
thr.fit.assert_called_once()
|
||||
assert_array_equal(x[category], thr.fit.call_args[0][1])
|
||||
assert_array_equal(y[category], thr.fit.call_args[0][2])
|
||||
|
||||
|
||||
def test_usage():
|
||||
def test_usage() -> None:
|
||||
solver = LearningSolver(
|
||||
components=[
|
||||
PrimalSolutionComponent(),
|
||||
|
||||
Reference in New Issue
Block a user