Add type annotations to components

This commit is contained in:
2021-01-21 15:54:23 -06:00
parent a98a783969
commit fc0835e694
12 changed files with 122 additions and 76 deletions

View File

@@ -2,8 +2,16 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Union, TYPE_CHECKING
from miplearn.instance import Instance
from miplearn.types import MIPSolveStats, TrainingSample
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class Component(ABC):
@@ -15,18 +23,35 @@ class Component(ABC):
strategy.
"""
def before_solve(self, solver, instance, model):
def before_solve(
self,
solver: LearningSolver,
instance: Instance,
model: Any,
) -> None:
"""
Method called by LearningSolver before the problem is solved.
Parameters
----------
solver
The solver calling this method.
instance
The instance being solved.
model
The concrete optimization model being solved.
"""
return
@abstractmethod
def after_solve(
self,
solver,
instance,
model,
stats,
training_data,
):
solver: LearningSolver,
instance: Instance,
model: Any,
stats: MIPSolveStats,
training_data: TrainingSample,
) -> None:
"""
Method called by LearningSolver after the problem is solved to optimality.
@@ -40,19 +65,23 @@ class Component(ABC):
The concrete optimization model being solved.
stats: dict
A dictionary containing statistics about the solution process, such as
number of nodes explored and running time. Components are free to add their own
statistics here. For example, PrimalSolutionComponent adds statistics regarding
the number of predicted variables. All statistics in this dictionary are exported
to the benchmark CSV file.
number of nodes explored and running time. Components are free to add
their own statistics here. For example, PrimalSolutionComponent adds
statistics regarding the number of predicted variables. All statistics in
this dictionary are exported to the benchmark CSV file.
training_data: dict
A dictionary containing data that may be useful for training machine learning
models and accelerating the solution process. Components are free to add their
own training data here. For example, PrimalSolutionComponent adds the current
primal solution. The data must be pickable.
A dictionary containing data that may be useful for training machine
learning models and accelerating the solution process. Components are
free to add their own training data here. For example,
PrimalSolutionComponent adds the current primal solution. The data must
be pickable.
"""
pass
def fit(self, training_instances):
def fit(
self,
training_instances: Union[List[str], List[Instance]],
) -> None:
return
def iteration_cb(self, solver, instance, model):