mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Request variable features/categories in bulk
This commit is contained in:
@@ -4,10 +4,11 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Hashable, TYPE_CHECKING
|
||||
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
|
||||
|
||||
from overrides import EnforceOverrides
|
||||
|
||||
from miplearn.features import Sample
|
||||
from miplearn.types import VariableName, Category
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +17,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
class Instance(ABC):
|
||||
class Instance(ABC, EnforceOverrides):
|
||||
"""
|
||||
Abstract class holding all the data necessary to generate a concrete model of the
|
||||
proble.
|
||||
@@ -62,10 +63,10 @@ class Instance(ABC):
|
||||
"""
|
||||
return [0.0]
|
||||
|
||||
def get_variable_features(self, var_name: VariableName) -> List[float]:
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
"""
|
||||
Returns a (1-dimensional) list of numerical features describing a particular
|
||||
decision variable.
|
||||
Returns dictionary mapping the name of each variable to a (1-dimensional) list
|
||||
of numerical features describing a particular decision variable.
|
||||
|
||||
In combination with instance features, variable features are used by
|
||||
LearningSolver to predict, among other things, the optimal value of each
|
||||
@@ -77,22 +78,25 @@ class Instance(ABC):
|
||||
length for all variables within the same category, for all relevant instances
|
||||
of the problem.
|
||||
|
||||
By default, returns [0.0].
|
||||
"""
|
||||
return [0.0]
|
||||
If features are not provided for a given variable, MIPLearn will use a
|
||||
default set of features.
|
||||
|
||||
def get_variable_category(self, var_name: VariableName) -> Optional[Category]:
|
||||
By default, returns {}.
|
||||
"""
|
||||
Returns the category for each decision variable.
|
||||
return {}
|
||||
|
||||
def get_variable_categories(self) -> Dict[str, Hashable]:
|
||||
"""
|
||||
Returns a dictionary mapping the name of each variable to its category.
|
||||
|
||||
If two variables have the same category, LearningSolver will use the same
|
||||
internal ML model to predict the values of both variables. If the returned
|
||||
category is None, ML models will ignore the variable.
|
||||
internal ML model to predict the values of both variables. If a variable is not
|
||||
listed in the dictionary, ML models will ignore the variable.
|
||||
|
||||
A category can be any hashable type, such as strings, numbers or tuples.
|
||||
By default, returns "default".
|
||||
By default, returns {}.
|
||||
"""
|
||||
return "default"
|
||||
return {}
|
||||
|
||||
def get_constraint_features(self, cid: str) -> List[float]:
|
||||
return [0.0]
|
||||
|
||||
@@ -6,12 +6,11 @@ import gc
|
||||
import gzip
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING
|
||||
from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.instance.base import logger, Instance
|
||||
from miplearn.types import VariableName, Category
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -47,14 +46,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self, var_name: VariableName) -> List[float]:
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_features(var_name)
|
||||
return self.instance.get_variable_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_category(self, var_name: VariableName) -> Optional[Category]:
|
||||
def get_variable_categories(self) -> Dict[str, Hashable]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_category(var_name)
|
||||
return self.instance.get_variable_categories()
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
|
||||
|
||||
Reference in New Issue
Block a user