mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make get_variable_{categories,features} return np.ndarray
This commit is contained in:
@@ -605,6 +605,7 @@ class PyomoTestInstanceKnapsack(Instance):
|
||||
self.weights = weights
|
||||
self.prices = prices
|
||||
self.capacity = capacity
|
||||
self.n = len(weights)
|
||||
|
||||
@overrides
|
||||
def to_model(self) -> pe.ConcreteModel:
|
||||
@@ -631,15 +632,17 @@ class PyomoTestInstanceKnapsack(Instance):
|
||||
)
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
return {
|
||||
f"x[{i}]": [
|
||||
self.weights[i],
|
||||
self.prices[i],
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
return np.vstack(
|
||||
[
|
||||
[[self.weights[i], self.prices[i]] for i in range(self.n)],
|
||||
[0.0, 0.0],
|
||||
]
|
||||
for i in range(len(self.weights))
|
||||
}
|
||||
)
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
return {f"x[{i}]": "default" for i in range(len(self.weights))}
|
||||
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
return np.array(
|
||||
["default" if n.decode().startswith("x") else "" for n in names],
|
||||
dtype="S",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user