Make get_variable_{categories,features} return np.ndarray

This commit is contained in:
2021-08-09 15:19:53 -05:00
parent 56b39b6c9c
commit 895cb962b6
13 changed files with 165 additions and 155 deletions

View File

@@ -9,7 +9,7 @@ from p_tqdm import p_umap
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import LearningSolveStats
from miplearn.types import LearningSolveStats, Category
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
@@ -101,8 +101,8 @@ class Component:
def fit_xy(
self,
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
x: Dict[Category, np.ndarray],
y: Dict[Category, np.ndarray],
) -> None:
"""
Given two dictionaries x and y, mapping the name of the category to matrices

View File

@@ -47,8 +47,8 @@ class PrimalSolutionComponent(Component):
assert isinstance(threshold, Threshold)
assert mode in ["exact", "heuristic"]
self.mode = mode
self.classifiers: Dict[str, Classifier] = {}
self.thresholds: Dict[str, Threshold] = {}
self.classifiers: Dict[Category, Classifier] = {}
self.thresholds: Dict[Category, Threshold] = {}
self.threshold_prototype = threshold
self.classifier_prototype = classifier
@@ -96,7 +96,7 @@ class PrimalSolutionComponent(Component):
def sample_predict(self, sample: Sample) -> Solution:
var_names = sample.get_array("static_var_names")
var_categories = sample.get_vector("static_var_categories")
var_categories = sample.get_array("static_var_categories")
assert var_names is not None
assert var_categories is not None
@@ -120,7 +120,7 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution
solution: Solution = {v: None for v in var_names}
category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
for (i, var_name) in enumerate(var_names):
category = var_categories[i]
if category not in category_offset:
@@ -146,7 +146,7 @@ class PrimalSolutionComponent(Component):
mip_var_values = sample.get_array("mip_var_values")
var_features = sample.get_vector_list("lp_var_features")
var_names = sample.get_array("static_var_names")
var_categories = sample.get_vector("static_var_categories")
var_categories = sample.get_array("static_var_categories")
if var_features is None:
var_features = sample.get_vector_list("static_var_features")
assert instance_features is not None
@@ -157,7 +157,7 @@ class PrimalSolutionComponent(Component):
for (i, var_name) in enumerate(var_names):
# Initialize categories
category = var_categories[i]
if category is None:
if len(category) == 0:
continue
if category not in x.keys():
x[category] = []
@@ -176,7 +176,7 @@ class PrimalSolutionComponent(Component):
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
"category to ''."
)
y[category].append([opt_value < 0.5, opt_value >= 0.5])
return x, y
@@ -230,8 +230,8 @@ class PrimalSolutionComponent(Component):
@overrides
def fit_xy(
self,
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
x: Dict[Category, np.ndarray],
y: Dict[Category, np.ndarray],
) -> None:
for category in x.keys():
clf = self.classifier_prototype.clone()