Make get_variable_{categories,features} return np.ndarray

This commit is contained in:
2021-08-09 15:19:53 -05:00
parent 56b39b6c9c
commit 895cb962b6
13 changed files with 165 additions and 155 deletions

View File

@@ -63,7 +63,7 @@ class Instance(ABC):
"""
return np.zeros(1)
def get_variable_features(self) -> Dict[str, List[float]]:
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
"""
Returns dictionary mapping the name of each variable to a (1-dimensional) list
of numerical features describing a particular decision variable.
@@ -81,11 +81,11 @@ class Instance(ABC):
If features are not provided for a given variable, MIPLearn will use a
default set of features.
By default, returns {}.
By default, returns [[0.0], ..., [0.0]].
"""
return {}
return np.zeros((len(names), 1))
def get_variable_categories(self) -> Dict[str, str]:
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
"""
Returns a dictionary mapping the name of each variable to its category.
@@ -93,9 +93,9 @@ class Instance(ABC):
internal ML model to predict the values of both variables. If a variable is not
listed in the dictionary, ML models will ignore the variable.
By default, returns {}.
By default, returns `names`.
"""
return {}
return names
def get_constraint_features(self) -> Dict[str, List[float]]:
return {}

View File

@@ -36,14 +36,14 @@ class FileInstance(Instance):
return self.instance.get_instance_features()
@overrides
def get_variable_features(self) -> Dict[str, List[float]]:
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_variable_features()
return self.instance.get_variable_features(names)
@overrides
def get_variable_categories(self) -> Dict[str, str]:
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_variable_categories()
return self.instance.get_variable_categories(names)
@overrides
def get_constraint_features(self) -> Dict[str, List[float]]:

View File

@@ -48,14 +48,14 @@ class PickleGzInstance(Instance):
return self.instance.get_instance_features()
@overrides
def get_variable_features(self) -> Dict[str, List[float]]:
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_variable_features()
return self.instance.get_variable_features(names)
@overrides
def get_variable_categories(self) -> Dict[str, str]:
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_variable_categories()
return self.instance.get_variable_categories(names)
@overrides
def get_constraint_features(self) -> Dict[str, List[float]]: