mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Request variable features/categories in bulk
This commit is contained in:
@@ -6,12 +6,11 @@ import gc
|
||||
import gzip
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING
|
||||
from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.instance.base import logger, Instance
|
||||
from miplearn.types import VariableName, Category
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -47,14 +46,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_features(self, var_name: VariableName) -> List[float]:
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_features(var_name)
|
||||
return self.instance.get_variable_features()
|
||||
|
||||
@overrides
|
||||
def get_variable_category(self, var_name: VariableName) -> Optional[Category]:
|
||||
def get_variable_categories(self) -> Dict[str, Hashable]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_variable_category(var_name)
|
||||
return self.instance.get_variable_categories()
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
|
||||
|
||||
Reference in New Issue
Block a user