Use np.array for Variables.names

This commit is contained in:
2021-08-08 07:24:14 -05:00
parent f69067aafd
commit 7d55d6f34c
10 changed files with 96 additions and 76 deletions

View File

@@ -5,7 +5,7 @@
import collections
import numbers
from math import log, isfinite
from typing import TYPE_CHECKING, Dict, Optional, List, Any, Tuple
from typing import TYPE_CHECKING, Dict, Optional, List, Any, Tuple, KeysView, cast
import numpy as np
@@ -34,7 +34,7 @@ class FeaturesExtractor:
variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put_array("static_var_lower_bounds", variables.lower_bounds)
sample.put_vector("static_var_names", variables.names)
sample.put_array("static_var_names", variables.names)
sample.put_array("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("static_var_types", variables.types)
sample.put_array("static_var_upper_bounds", variables.upper_bounds)
@@ -139,12 +139,29 @@ class FeaturesExtractor:
instance: "Instance",
sample: Sample,
) -> Tuple[List, List]:
categories: List[Optional[str]] = []
user_features: List[Optional[List[float]]] = []
var_features_dict = instance.get_variable_features()
var_categories_dict = instance.get_variable_categories()
var_names = sample.get_vector("static_var_names")
# Query variable names
var_names = sample.get_array("static_var_names")
assert var_names is not None
# Query variable features and categories
var_features_dict = {
v.encode(): f for (v, f) in instance.get_variable_features().items()
}
var_categories_dict = {
v.encode(): f for (v, f) in instance.get_variable_categories().items()
}
# Assert that variables in user-provided dicts actually exist
var_names_set = set(var_names)
for keys in [var_features_dict.keys(), var_categories_dict.keys()]:
for vn in cast(KeysView, keys):
assert (
vn in var_names_set
), f"Variable {vn!r} not found in the problem; {var_names_set}"
# Assemble into compact lists
user_features: List[Optional[List[float]]] = []
categories: List[Optional[str]] = []
for (i, var_name) in enumerate(var_names):
if var_name not in var_categories_dict:
user_features.append(None)