mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Use np.array for Variables.names
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
import collections
|
||||
import numbers
|
||||
from math import log, isfinite
|
||||
from typing import TYPE_CHECKING, Dict, Optional, List, Any, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Optional, List, Any, Tuple, KeysView, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -34,7 +34,7 @@ class FeaturesExtractor:
|
||||
variables = solver.get_variables(with_static=True)
|
||||
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
|
||||
sample.put_array("static_var_lower_bounds", variables.lower_bounds)
|
||||
sample.put_vector("static_var_names", variables.names)
|
||||
sample.put_array("static_var_names", variables.names)
|
||||
sample.put_array("static_var_obj_coeffs", variables.obj_coeffs)
|
||||
sample.put_vector("static_var_types", variables.types)
|
||||
sample.put_array("static_var_upper_bounds", variables.upper_bounds)
|
||||
@@ -139,12 +139,29 @@ class FeaturesExtractor:
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> Tuple[List, List]:
|
||||
categories: List[Optional[str]] = []
|
||||
user_features: List[Optional[List[float]]] = []
|
||||
var_features_dict = instance.get_variable_features()
|
||||
var_categories_dict = instance.get_variable_categories()
|
||||
var_names = sample.get_vector("static_var_names")
|
||||
# Query variable names
|
||||
var_names = sample.get_array("static_var_names")
|
||||
assert var_names is not None
|
||||
|
||||
# Query variable features and categories
|
||||
var_features_dict = {
|
||||
v.encode(): f for (v, f) in instance.get_variable_features().items()
|
||||
}
|
||||
var_categories_dict = {
|
||||
v.encode(): f for (v, f) in instance.get_variable_categories().items()
|
||||
}
|
||||
|
||||
# Assert that variables in user-provided dicts actually exist
|
||||
var_names_set = set(var_names)
|
||||
for keys in [var_features_dict.keys(), var_categories_dict.keys()]:
|
||||
for vn in cast(KeysView, keys):
|
||||
assert (
|
||||
vn in var_names_set
|
||||
), f"Variable {vn!r} not found in the problem; {var_names_set}"
|
||||
|
||||
# Assemble into compact lists
|
||||
user_features: List[Optional[List[float]]] = []
|
||||
categories: List[Optional[str]] = []
|
||||
for (i, var_name) in enumerate(var_names):
|
||||
if var_name not in var_categories_dict:
|
||||
user_features.append(None)
|
||||
|
||||
Reference in New Issue
Block a user