mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make get_variable_{categories,features} return np.ndarray
This commit is contained in:
@@ -63,7 +63,7 @@ class Instance(ABC):
|
||||
"""
|
||||
return np.zeros(1)
|
||||
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns dictionary mapping the name of each variable to a (1-dimensional) list
|
||||
of numerical features describing a particular decision variable.
|
||||
@@ -81,11 +81,11 @@ class Instance(ABC):
|
||||
If features are not provided for a given variable, MIPLearn will use a
|
||||
default set of features.
|
||||
|
||||
By default, returns {}.
|
||||
By default, returns [[0.0], ..., [0.0]].
|
||||
"""
|
||||
return {}
|
||||
return np.zeros((len(names), 1))
|
||||
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Returns a dictionary mapping the name of each variable to its category.
|
||||
|
||||
@@ -93,9 +93,9 @@ class Instance(ABC):
|
||||
internal ML model to predict the values of both variables. If a variable is not
|
||||
listed in the dictionary, ML models will ignore the variable.
|
||||
|
||||
By default, returns {}.
|
||||
By default, returns `names`.
|
||||
"""
|
||||
return {}
|
||||
return names
|
||||
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
return {}
|
||||
|
||||
@@ -36,14 +36,14 @@ class FileInstance(Instance):
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_features()
|
||||
return self.instance.get_variable_features(names)
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_categories()
|
||||
return self.instance.get_variable_categories(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
|
||||
@@ -48,14 +48,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_features()
|
||||
return self.instance.get_variable_features(names)
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_categories()
|
||||
return self.instance.get_variable_categories(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
|
||||
Reference in New Issue
Block a user