Remove EnforceOverrides; automatically convert np.ndarray features

This commit is contained in:
2021-04-08 07:50:16 -05:00
parent 157825a345
commit 6330354c47
6 changed files with 17 additions and 8 deletions

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, List, Hashable
from miplearn.types import Solution, VariableName, Category
import numpy as np
if TYPE_CHECKING:
from miplearn.solvers.internal import InternalSolver
@@ -83,6 +84,8 @@ class FeaturesExtractor:
f"Found {type(category).__name__} instead for var={var_name}."
)
user_features = instance.get_variable_features(var_name)
if isinstance(user_features, np.ndarray):
user_features = user_features.tolist()
assert isinstance(user_features, list), (
f"Variable features must be a list. "
f"Found {type(user_features).__name__} instead for "
@@ -115,6 +118,8 @@ class FeaturesExtractor:
f"Found {type(category).__name__} instead for cid={cid}.",
)
user_features = instance.get_constraint_features(cid)
if isinstance(user_features, np.ndarray):
user_features = user_features.tolist()
assert isinstance(user_features, list), (
f"Constraint features must be a list. "
f"Found {type(user_features).__name__} instead for cid={cid}."
@@ -141,6 +146,8 @@ class FeaturesExtractor:
) -> InstanceFeatures:
assert features.constraints is not None
user_features = instance.get_instance_features()
if isinstance(user_features, np.ndarray):
user_features = user_features.tolist()
assert isinstance(user_features, list), (
f"Instance features must be a list. "
f"Found {type(user_features).__name__} instead."