mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make get_variable_{categories,features} return np.ndarray
This commit is contained in:
@@ -98,11 +98,13 @@ class MultiKnapsackInstance(Instance):
|
||||
return np.array([float(np.mean(self.prices))] + list(self.capacities))
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
return {
|
||||
f"x[{i}]": [self.prices[i] + list(self.weights[:, i])]
|
||||
for i in range(self.n)
|
||||
}
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
features = []
|
||||
for i in range(len(self.weights)):
|
||||
f = [self.prices[i]]
|
||||
f.extend(self.weights[:, i])
|
||||
features.append(f)
|
||||
return np.array(features)
|
||||
|
||||
|
||||
# noinspection PyPep8Naming
|
||||
|
||||
@@ -66,9 +66,11 @@ class MaxWeightStableSetInstance(Instance):
|
||||
return model
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
features = {}
|
||||
for v1 in self.nodes:
|
||||
def get_variable_features(self, names: np.ndarray) -> np.ndarray:
|
||||
features = []
|
||||
assert len(names) == len(self.nodes)
|
||||
for i, v1 in enumerate(self.nodes):
|
||||
assert names[i] == f"x[{v1}]".encode()
|
||||
neighbor_weights = [0.0] * 15
|
||||
neighbor_degrees = [100.0] * 15
|
||||
for v2 in self.graph.neighbors(v1):
|
||||
@@ -80,12 +82,12 @@ class MaxWeightStableSetInstance(Instance):
|
||||
f += neighbor_weights[:5]
|
||||
f += neighbor_degrees[:5]
|
||||
f += [self.graph.degree(v1)]
|
||||
features[f"x[{v1}]"] = f
|
||||
return features
|
||||
features.append(f)
|
||||
return np.array(features)
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
return {f"x[{v}]": "default" for v in self.nodes}
|
||||
def get_variable_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
return np.array(["default" for _ in names], dtype="S")
|
||||
|
||||
|
||||
class MaxWeightStableSetGenerator:
|
||||
|
||||
@@ -80,10 +80,6 @@ class TravelingSalesmanInstance(Instance):
|
||||
)
|
||||
return model
|
||||
|
||||
@overrides
|
||||
def get_variable_categories(self) -> Dict[str, str]:
|
||||
return {f"x[{e}]": f"x[{e}]" for e in self.edges}
|
||||
|
||||
@overrides
|
||||
def find_violated_lazy_constraints(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user