mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Request constraint features/categories in bulk
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[bool]]] = {}
|
||||
cids: Dict[Hashable, List[str]] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
constr_features_dict = instance.get_constraint_features()
|
||||
for cid in self.known_cids:
|
||||
# Initialize categories
|
||||
category = instance.get_constraint_category(cid)
|
||||
if cid in constr_categories_dict:
|
||||
category = constr_categories_dict[cid]
|
||||
else:
|
||||
category = cid
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
features.extend(sample.after_load.instance.to_list())
|
||||
features.extend(instance.get_constraint_features(cid))
|
||||
if cid in constr_features_dict:
|
||||
features.extend(constr_features_dict[cid])
|
||||
for ci in features:
|
||||
assert isinstance(ci, float), (
|
||||
f"Constraint features must be a list of floats. "
|
||||
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
|
||||
tn: Dict[Hashable, int] = {}
|
||||
fp: Dict[Hashable, int] = {}
|
||||
fn: Dict[Hashable, int] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
for cid in self.known_cids:
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
if cid not in constr_categories_dict:
|
||||
continue
|
||||
category = constr_categories_dict[cid]
|
||||
if category not in tp.keys():
|
||||
tp[category] = 0
|
||||
tn[category] = 0
|
||||
|
||||
Reference in New Issue
Block a user