mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-09 19:08:51 -06:00
Remove EnforceOverrides; automatically convert np.ndarray features
This commit is contained in:
@@ -8,6 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, List, Hashable
|
||||
|
||||
from miplearn.types import Solution, VariableName, Category
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.internal import InternalSolver
|
||||
@@ -83,6 +84,8 @@ class FeaturesExtractor:
|
||||
f"Found {type(category).__name__} instead for var={var_name}."
|
||||
)
|
||||
user_features = instance.get_variable_features(var_name)
|
||||
if isinstance(user_features, np.ndarray):
|
||||
user_features = user_features.tolist()
|
||||
assert isinstance(user_features, list), (
|
||||
f"Variable features must be a list. "
|
||||
f"Found {type(user_features).__name__} instead for "
|
||||
@@ -115,6 +118,8 @@ class FeaturesExtractor:
|
||||
f"Found {type(category).__name__} instead for cid={cid}.",
|
||||
)
|
||||
user_features = instance.get_constraint_features(cid)
|
||||
if isinstance(user_features, np.ndarray):
|
||||
user_features = user_features.tolist()
|
||||
assert isinstance(user_features, list), (
|
||||
f"Constraint features must be a list. "
|
||||
f"Found {type(user_features).__name__} instead for cid={cid}."
|
||||
@@ -141,6 +146,8 @@ class FeaturesExtractor:
|
||||
) -> InstanceFeatures:
|
||||
assert features.constraints is not None
|
||||
user_features = instance.get_instance_features()
|
||||
if isinstance(user_features, np.ndarray):
|
||||
user_features = user_features.tolist()
|
||||
assert isinstance(user_features, list), (
|
||||
f"Instance features must be a list. "
|
||||
f"Found {type(user_features).__name__} instead."
|
||||
|
||||
Reference in New Issue
Block a user