mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Replace Hashable by str
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
|
||||
from typing import Any, List, TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
|
||||
@@ -83,7 +83,7 @@ class Instance(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_variable_categories(self) -> Dict[str, Hashable]:
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
"""
|
||||
Returns a dictionary mapping the name of each variable to its category.
|
||||
|
||||
@@ -91,7 +91,6 @@ class Instance(ABC):
|
||||
internal ML model to predict the values of both variables. If a variable is not
|
||||
listed in the dictionary, ML models will ignore the variable.
|
||||
|
||||
A category can be any hashable type, such as strings, numbers or tuples.
|
||||
By default, returns {}.
|
||||
"""
|
||||
return {}
|
||||
@@ -99,7 +98,7 @@ class Instance(ABC):
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
return {}
|
||||
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
def get_constraint_categories(self) -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
@@ -115,7 +114,7 @@ class Instance(ABC):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[Hashable]:
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns lazy constraint violations found for the current solution.
|
||||
|
||||
@@ -125,10 +124,10 @@ class Instance(ABC):
|
||||
resolve the problem. The process repeats until no further lazy constraint
|
||||
violations are found.
|
||||
|
||||
Each "violation" is simply a string, a tuple or any other hashable type which
|
||||
allows the instance to identify unambiguously which lazy constraint should be
|
||||
generated. In the Traveling Salesman Problem, for example, a subtour
|
||||
violation could be a frozen set containing the cities in the subtour.
|
||||
Each "violation" is simply a string which allows the instance to identify
|
||||
unambiguously which lazy constraint should be generated. In the Traveling
|
||||
Salesman Problem, for example, a subtour violation could be a string
|
||||
containing the cities in the subtour.
|
||||
|
||||
The current solution can be queried with `solver.get_solution()`. If the solver
|
||||
is configured to use lazy callbacks, this solution may be non-integer.
|
||||
@@ -141,7 +140,7 @@ class Instance(ABC):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
violation: str,
|
||||
) -> None:
|
||||
"""
|
||||
Adds constraints to the model to ensure that the given violation is fixed.
|
||||
@@ -167,14 +166,14 @@ class Instance(ABC):
|
||||
def has_user_cuts(self) -> bool:
|
||||
return False
|
||||
|
||||
def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
|
||||
def find_violated_user_cuts(self, model: Any) -> List[str]:
|
||||
return []
|
||||
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
violation: str,
|
||||
) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import gc
|
||||
import gzip
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
|
||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
@@ -52,7 +52,7 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_variable_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, Hashable]:
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_categories()
|
||||
|
||||
@@ -62,7 +62,7 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_constraint_features()
|
||||
|
||||
@overrides
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
def get_constraint_categories(self) -> Dict[str, str]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_categories()
|
||||
|
||||
@@ -86,7 +86,7 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[Hashable]:
|
||||
) -> List[str]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -95,13 +95,13 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
violation: str,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
|
||||
def find_violated_user_cuts(self, model: Any) -> List[str]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -110,7 +110,7 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
violation: str,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
|
||||
Reference in New Issue
Block a user