mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make get_variable_{categories,features} return np.ndarray
This commit is contained in:
@@ -9,7 +9,7 @@ from p_tqdm import p_umap
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
from miplearn.types import LearningSolveStats, Category
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
@@ -101,8 +101,8 @@ class Component:
|
||||
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[str, np.ndarray],
|
||||
y: Dict[str, np.ndarray],
|
||||
x: Dict[Category, np.ndarray],
|
||||
y: Dict[Category, np.ndarray],
|
||||
) -> None:
|
||||
"""
|
||||
Given two dictionaries x and y, mapping the name of the category to matrices
|
||||
|
||||
@@ -47,8 +47,8 @@ class PrimalSolutionComponent(Component):
|
||||
assert isinstance(threshold, Threshold)
|
||||
assert mode in ["exact", "heuristic"]
|
||||
self.mode = mode
|
||||
self.classifiers: Dict[str, Classifier] = {}
|
||||
self.thresholds: Dict[str, Threshold] = {}
|
||||
self.classifiers: Dict[Category, Classifier] = {}
|
||||
self.thresholds: Dict[Category, Threshold] = {}
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
|
||||
@@ -96,7 +96,7 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
def sample_predict(self, sample: Sample) -> Solution:
|
||||
var_names = sample.get_array("static_var_names")
|
||||
var_categories = sample.get_vector("static_var_categories")
|
||||
var_categories = sample.get_array("static_var_categories")
|
||||
assert var_names is not None
|
||||
assert var_categories is not None
|
||||
|
||||
@@ -120,7 +120,7 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
# Convert y_pred into solution
|
||||
solution: Solution = {v: None for v in var_names}
|
||||
category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
|
||||
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
||||
for (i, var_name) in enumerate(var_names):
|
||||
category = var_categories[i]
|
||||
if category not in category_offset:
|
||||
@@ -146,7 +146,7 @@ class PrimalSolutionComponent(Component):
|
||||
mip_var_values = sample.get_array("mip_var_values")
|
||||
var_features = sample.get_vector_list("lp_var_features")
|
||||
var_names = sample.get_array("static_var_names")
|
||||
var_categories = sample.get_vector("static_var_categories")
|
||||
var_categories = sample.get_array("static_var_categories")
|
||||
if var_features is None:
|
||||
var_features = sample.get_vector_list("static_var_features")
|
||||
assert instance_features is not None
|
||||
@@ -157,7 +157,7 @@ class PrimalSolutionComponent(Component):
|
||||
for (i, var_name) in enumerate(var_names):
|
||||
# Initialize categories
|
||||
category = var_categories[i]
|
||||
if category is None:
|
||||
if len(category) == 0:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
@@ -176,7 +176,7 @@ class PrimalSolutionComponent(Component):
|
||||
f"Variable {var_name} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
"category to ''."
|
||||
)
|
||||
y[category].append([opt_value < 0.5, opt_value >= 0.5])
|
||||
return x, y
|
||||
@@ -230,8 +230,8 @@ class PrimalSolutionComponent(Component):
|
||||
@overrides
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[str, np.ndarray],
|
||||
y: Dict[str, np.ndarray],
|
||||
x: Dict[Category, np.ndarray],
|
||||
y: Dict[Category, np.ndarray],
|
||||
) -> None:
|
||||
for category in x.keys():
|
||||
clf = self.classifier_prototype.clone()
|
||||
|
||||
Reference in New Issue
Block a user