mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Request constraint features/categories in bulk
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[bool]]] = {}
|
||||
cids: Dict[Hashable, List[str]] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
constr_features_dict = instance.get_constraint_features()
|
||||
for cid in self.known_cids:
|
||||
# Initialize categories
|
||||
category = instance.get_constraint_category(cid)
|
||||
if cid in constr_categories_dict:
|
||||
category = constr_categories_dict[cid]
|
||||
else:
|
||||
category = cid
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
features.extend(sample.after_load.instance.to_list())
|
||||
features.extend(instance.get_constraint_features(cid))
|
||||
if cid in constr_features_dict:
|
||||
features.extend(constr_features_dict[cid])
|
||||
for ci in features:
|
||||
assert isinstance(ci, float), (
|
||||
f"Constraint features must be a list of floats. "
|
||||
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
|
||||
tn: Dict[Hashable, int] = {}
|
||||
fp: Dict[Hashable, int] = {}
|
||||
fn: Dict[Hashable, int] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
for cid in self.known_cids:
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
if cid not in constr_categories_dict:
|
||||
continue
|
||||
category = constr_categories_dict[cid]
|
||||
if category not in tp.keys():
|
||||
tp[category] = 0
|
||||
tn[category] = 0
|
||||
|
||||
@@ -237,16 +237,25 @@ class FeaturesExtractor:
|
||||
user_features: List[Optional[List[float]]] = []
|
||||
categories: List[Optional[Hashable]] = []
|
||||
lazy: List[bool] = []
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
constr_features_dict = instance.get_constraint_features()
|
||||
|
||||
for (cidx, cname) in enumerate(features.constraints.names):
|
||||
category: Optional[Hashable] = cname
|
||||
if cname in constr_categories_dict:
|
||||
category = constr_categories_dict[cname]
|
||||
if category is None:
|
||||
user_features.append(None)
|
||||
categories.append(None)
|
||||
continue
|
||||
assert isinstance(category, collections.Hashable), (
|
||||
f"Constraint category must be hashable. "
|
||||
f"Found {type(category).__name__} instead for cname={cname}.",
|
||||
)
|
||||
categories.append(category)
|
||||
cf: Optional[List[float]] = None
|
||||
category: Optional[Hashable] = instance.get_constraint_category(cname)
|
||||
if category is not None:
|
||||
categories.append(category)
|
||||
assert isinstance(category, collections.Hashable), (
|
||||
f"Constraint category must be hashable. "
|
||||
f"Found {type(category).__name__} instead for cname={cname}.",
|
||||
)
|
||||
cf = instance.get_constraint_features(cname)
|
||||
if cname in constr_features_dict:
|
||||
cf = constr_features_dict[cname]
|
||||
if isinstance(cf, np.ndarray):
|
||||
cf = cf.tolist()
|
||||
assert isinstance(cf, list), (
|
||||
@@ -258,10 +267,8 @@ class FeaturesExtractor:
|
||||
f"Constraint features must be a list of numbers. "
|
||||
f"Found {type(f).__name__} instead for cname={cname}."
|
||||
)
|
||||
user_features.append(list(cf))
|
||||
else:
|
||||
user_features.append(None)
|
||||
categories.append(None)
|
||||
cf = list(cf)
|
||||
user_features.append(cf)
|
||||
if has_static_lazy:
|
||||
lazy.append(instance.is_constraint_lazy(cname))
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
|
||||
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.features import Sample
|
||||
|
||||
@@ -96,11 +96,11 @@ class Instance(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_constraint_features(self, cid: str) -> List[float]:
|
||||
return [0.0]
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
return {}
|
||||
|
||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
||||
return cid
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
return {}
|
||||
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_variable_categories()
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_features(cid)
|
||||
return self.instance.get_constraint_features()
|
||||
|
||||
@overrides
|
||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_category(cid)
|
||||
return self.instance.get_constraint_categories()
|
||||
|
||||
@overrides
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user