mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Disallow untyped calls and incomplete defs
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
import gzip
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -20,7 +21,7 @@ class Instance(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_model(self):
|
||||
def to_model(self) -> Any:
|
||||
"""
|
||||
Returns a concrete Pyomo model corresponding to this instance.
|
||||
"""
|
||||
|
||||
@@ -4,19 +4,20 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedirectOutput:
|
||||
def __init__(self, streams):
|
||||
def __init__(self, streams: List[Any]):
|
||||
self.streams = streams
|
||||
|
||||
def write(self, data):
|
||||
def write(self, data: Any) -> None:
|
||||
for stream in self.streams:
|
||||
stream.write(data)
|
||||
|
||||
def flush(self):
|
||||
def flush(self) -> None:
|
||||
for stream in self.streams:
|
||||
stream.flush()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
import sys
|
||||
from io import StringIO
|
||||
from random import randint
|
||||
from typing import List, Any, Dict, Union
|
||||
from typing import List, Any, Dict, Union, Tuple, Optional
|
||||
|
||||
from . import RedirectOutput
|
||||
from .internal import (
|
||||
@@ -73,13 +73,14 @@ class GurobiSolver(InternalSolver):
|
||||
self.model.update()
|
||||
self._update_vars()
|
||||
|
||||
def _raise_if_callback(self):
|
||||
def _raise_if_callback(self) -> None:
|
||||
if self.cb_where is not None:
|
||||
raise Exception("method cannot be called from a callback")
|
||||
|
||||
def _update_vars(self):
|
||||
def _update_vars(self) -> None:
|
||||
self._all_vars = {}
|
||||
self._bin_vars = {}
|
||||
idx: Union[Tuple, List[int], int]
|
||||
for var in self.model.getVars():
|
||||
m = re.search(r"([^[]*)\[(.*)]", var.varName)
|
||||
if m is None:
|
||||
@@ -100,7 +101,7 @@ class GurobiSolver(InternalSolver):
|
||||
self._bin_vars[name] = {}
|
||||
self._bin_vars[name][idx] = var
|
||||
|
||||
def _apply_params(self, streams):
|
||||
def _apply_params(self, streams: List[Any]) -> None:
|
||||
with RedirectOutput(streams):
|
||||
for (name, value) in self.params.items():
|
||||
self.model.setParam(name, value)
|
||||
@@ -271,7 +272,7 @@ class GurobiSolver(InternalSolver):
|
||||
else:
|
||||
self.model.addConstr(constraint, name=name)
|
||||
|
||||
def _clear_warm_start(self):
|
||||
def _clear_warm_start(self) -> None:
|
||||
for (varname, vardict) in self._all_vars.items():
|
||||
for (idx, var) in vardict.items():
|
||||
var.start = self.GRB.UNDEFINED
|
||||
@@ -338,14 +339,18 @@ class GurobiSolver(InternalSolver):
|
||||
self.model = self.model.relax()
|
||||
self._update_vars()
|
||||
|
||||
def _extract_warm_start_value(self, log):
|
||||
def _extract_warm_start_value(self, log: str) -> Optional[float]:
|
||||
ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)")
|
||||
if ws is not None:
|
||||
ws = float(ws)
|
||||
return ws
|
||||
if ws is None:
|
||||
return None
|
||||
return float(ws)
|
||||
|
||||
@staticmethod
|
||||
def __extract(log, regexp, default=None):
|
||||
def __extract(
|
||||
log: str,
|
||||
regexp: str,
|
||||
default: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
value = default
|
||||
for line in log.splitlines():
|
||||
matches = re.findall(regexp, line)
|
||||
|
||||
@@ -192,7 +192,7 @@ class InternalSolver(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_constraint(self, cobj: Constraint):
|
||||
def add_constraint(self, cobj: Constraint) -> None:
|
||||
"""
|
||||
Adds a single constraint to the model.
|
||||
"""
|
||||
@@ -209,7 +209,7 @@ class InternalSolver(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_constraint_satisfied(self, cobj: Constraint):
|
||||
def is_constraint_satisfied(self, cobj: Constraint) -> bool:
|
||||
"""
|
||||
Returns True if the current solution satisfies the given constraint.
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, List, Dict, Optional
|
||||
|
||||
import pyomo
|
||||
from pyomo import environ as pe
|
||||
@@ -169,18 +169,18 @@ class BasePyomoSolver(InternalSolver):
|
||||
variables[str(var)] += [index]
|
||||
return variables
|
||||
|
||||
def _clear_warm_start(self):
|
||||
def _clear_warm_start(self) -> None:
|
||||
for var in self._all_vars:
|
||||
if not var.fixed:
|
||||
var.value = None
|
||||
self._is_warm_start_available = False
|
||||
|
||||
def _update_obj(self):
|
||||
def _update_obj(self) -> None:
|
||||
self._obj_sense = "max"
|
||||
if self._pyomo_solver._objective.sense == pyomo.core.kernel.objective.minimize:
|
||||
self._obj_sense = "min"
|
||||
|
||||
def _update_vars(self):
|
||||
def _update_vars(self) -> None:
|
||||
self._all_vars = []
|
||||
self._bin_vars = []
|
||||
self._varname_to_var = {}
|
||||
@@ -191,7 +191,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
if var[idx].domain == pyomo.core.base.set_types.Binary:
|
||||
self._bin_vars += [var[idx]]
|
||||
|
||||
def _update_constrs(self):
|
||||
def _update_constrs(self) -> None:
|
||||
self._cname_to_constr = {}
|
||||
for constr in self.model.component_objects(Constraint):
|
||||
self._cname_to_constr[constr.name] = constr
|
||||
@@ -220,7 +220,11 @@ class BasePyomoSolver(InternalSolver):
|
||||
self._update_constrs()
|
||||
|
||||
@staticmethod
|
||||
def __extract(log, regexp, default=None):
|
||||
def __extract(
|
||||
log: str,
|
||||
regexp: Optional[str],
|
||||
default: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
if regexp is None:
|
||||
return default
|
||||
value = default
|
||||
@@ -231,22 +235,25 @@ class BasePyomoSolver(InternalSolver):
|
||||
value = matches[0]
|
||||
return value
|
||||
|
||||
def _extract_warm_start_value(self, log):
|
||||
def _extract_warm_start_value(self, log: str) -> Optional[float]:
|
||||
value = self.__extract(log, self._get_warm_start_regexp())
|
||||
if value is not None:
|
||||
value = float(value)
|
||||
return value
|
||||
if value is None:
|
||||
return None
|
||||
return float(value)
|
||||
|
||||
def _extract_node_count(self, log):
|
||||
return self.__extract(log, self._get_node_count_regexp())
|
||||
def _extract_node_count(self, log: str) -> Optional[int]:
|
||||
value = self.__extract(log, self._get_node_count_regexp())
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
def get_constraint_ids(self):
|
||||
return list(self._cname_to_constr.keys())
|
||||
|
||||
def _get_warm_start_regexp(self):
|
||||
def _get_warm_start_regexp(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def _get_node_count_regexp(self):
|
||||
def _get_node_count_regexp(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def extract_constraint(self, cid):
|
||||
|
||||
Reference in New Issue
Block a user