mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Request constraint features/categories in bulk
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[bool]]] = {}
|
||||
cids: Dict[Hashable, List[str]] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
constr_features_dict = instance.get_constraint_features()
|
||||
for cid in self.known_cids:
|
||||
# Initialize categories
|
||||
category = instance.get_constraint_category(cid)
|
||||
if cid in constr_categories_dict:
|
||||
category = constr_categories_dict[cid]
|
||||
else:
|
||||
category = cid
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
features.extend(sample.after_load.instance.to_list())
|
||||
features.extend(instance.get_constraint_features(cid))
|
||||
if cid in constr_features_dict:
|
||||
features.extend(constr_features_dict[cid])
|
||||
for ci in features:
|
||||
assert isinstance(ci, float), (
|
||||
f"Constraint features must be a list of floats. "
|
||||
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
|
||||
tn: Dict[Hashable, int] = {}
|
||||
fp: Dict[Hashable, int] = {}
|
||||
fn: Dict[Hashable, int] = {}
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
for cid in self.known_cids:
|
||||
category = instance.get_constraint_category(cid)
|
||||
if category is None:
|
||||
if cid not in constr_categories_dict:
|
||||
continue
|
||||
category = constr_categories_dict[cid]
|
||||
if category not in tp.keys():
|
||||
tp[category] = 0
|
||||
tn[category] = 0
|
||||
|
||||
@@ -237,16 +237,25 @@ class FeaturesExtractor:
|
||||
user_features: List[Optional[List[float]]] = []
|
||||
categories: List[Optional[Hashable]] = []
|
||||
lazy: List[bool] = []
|
||||
constr_categories_dict = instance.get_constraint_categories()
|
||||
constr_features_dict = instance.get_constraint_features()
|
||||
|
||||
for (cidx, cname) in enumerate(features.constraints.names):
|
||||
cf: Optional[List[float]] = None
|
||||
category: Optional[Hashable] = instance.get_constraint_category(cname)
|
||||
if category is not None:
|
||||
categories.append(category)
|
||||
category: Optional[Hashable] = cname
|
||||
if cname in constr_categories_dict:
|
||||
category = constr_categories_dict[cname]
|
||||
if category is None:
|
||||
user_features.append(None)
|
||||
categories.append(None)
|
||||
continue
|
||||
assert isinstance(category, collections.Hashable), (
|
||||
f"Constraint category must be hashable. "
|
||||
f"Found {type(category).__name__} instead for cname={cname}.",
|
||||
)
|
||||
cf = instance.get_constraint_features(cname)
|
||||
categories.append(category)
|
||||
cf: Optional[List[float]] = None
|
||||
if cname in constr_features_dict:
|
||||
cf = constr_features_dict[cname]
|
||||
if isinstance(cf, np.ndarray):
|
||||
cf = cf.tolist()
|
||||
assert isinstance(cf, list), (
|
||||
@@ -258,10 +267,8 @@ class FeaturesExtractor:
|
||||
f"Constraint features must be a list of numbers. "
|
||||
f"Found {type(f).__name__} instead for cname={cname}."
|
||||
)
|
||||
user_features.append(list(cf))
|
||||
else:
|
||||
user_features.append(None)
|
||||
categories.append(None)
|
||||
cf = list(cf)
|
||||
user_features.append(cf)
|
||||
if has_static_lazy:
|
||||
lazy.append(instance.is_constraint_lazy(cname))
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
|
||||
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.features import Sample
|
||||
|
||||
@@ -96,11 +96,11 @@ class Instance(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_constraint_features(self, cid: str) -> List[float]:
|
||||
return [0.0]
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
return {}
|
||||
|
||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
||||
return cid
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
return {}
|
||||
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_variable_categories()
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_features(cid)
|
||||
return self.instance.get_constraint_features()
|
||||
|
||||
@overrides
|
||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
||||
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_category(cid)
|
||||
return self.instance.get_constraint_categories()
|
||||
|
||||
@overrides
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
|
||||
@@ -6,7 +6,6 @@ from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||
@@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
|
||||
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
|
||||
return_value=[5.0]
|
||||
)
|
||||
instances[0].get_constraint_category = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
instances[0].get_constraint_categories = Mock( # type: ignore
|
||||
return_value={
|
||||
"c1": "type-a",
|
||||
"c2": "type-a",
|
||||
"c3": "type-b",
|
||||
"c4": "type-b",
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[0].get_constraint_features = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
return_value={
|
||||
"c1": [1.0, 2.0, 3.0],
|
||||
"c2": [4.0, 5.0, 6.0],
|
||||
"c3": [1.0, 2.0],
|
||||
"c4": [3.0, 4.0],
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[1].samples = [
|
||||
Sample(
|
||||
@@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
|
||||
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
||||
return_value=[8.0]
|
||||
)
|
||||
instances[1].get_constraint_category = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
instances[1].get_constraint_categories = Mock( # type: ignore
|
||||
return_value={
|
||||
"c1": None,
|
||||
"c2": "type-a",
|
||||
"c3": "type-b",
|
||||
"c4": "type-b",
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
instances[1].get_constraint_features = Mock( # type: ignore
|
||||
side_effect=lambda cid: {
|
||||
return_value={
|
||||
"c2": [7.0, 8.0, 9.0],
|
||||
"c3": [5.0, 6.0],
|
||||
"c4": [7.0, 8.0],
|
||||
}[cid]
|
||||
}
|
||||
)
|
||||
return instances
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ def test_knapsack() -> None:
|
||||
sa_rhs_up=[2.0],
|
||||
senses=["="],
|
||||
slacks=[0.0],
|
||||
user_features=[[0.0]],
|
||||
user_features=[None],
|
||||
),
|
||||
)
|
||||
assert_equals(
|
||||
|
||||
Reference in New Issue
Block a user