Request constraint features/categories in bulk

This commit is contained in:
2021-06-29 09:54:35 -05:00
parent 8118ab4110
commit a5092cc2b9
6 changed files with 51 additions and 38 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
import numpy as np
from overrides import overrides
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for cid in self.known_cids:
# Initialize categories
category = instance.get_constraint_category(cid)
if cid in constr_categories_dict:
category = constr_categories_dict[cid]
else:
category = cid
if category is None:
continue
if category not in x:
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
assert sample.after_load is not None
assert sample.after_load.instance is not None
features.extend(sample.after_load.instance.to_list())
features.extend(instance.get_constraint_features(cid))
if cid in constr_features_dict:
features.extend(constr_features_dict[cid])
for ci in features:
assert isinstance(ci, float), (
f"Constraint features must be a list of floats. "
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
tn: Dict[Hashable, int] = {}
fp: Dict[Hashable, int] = {}
fn: Dict[Hashable, int] = {}
constr_categories_dict = instance.get_constraint_categories()
for cid in self.known_cids:
category = instance.get_constraint_category(cid)
if category is None:
if cid not in constr_categories_dict:
continue
category = constr_categories_dict[cid]
if category not in tp.keys():
tp[category] = 0
tn[category] = 0

View File

@@ -237,16 +237,25 @@ class FeaturesExtractor:
user_features: List[Optional[List[float]]] = []
categories: List[Optional[Hashable]] = []
lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for (cidx, cname) in enumerate(features.constraints.names):
category: Optional[Hashable] = cname
if cname in constr_categories_dict:
category = constr_categories_dict[cname]
if category is None:
user_features.append(None)
categories.append(None)
continue
assert isinstance(category, collections.Hashable), (
f"Constraint category must be hashable. "
f"Found {type(category).__name__} instead for cname={cname}.",
)
categories.append(category)
cf: Optional[List[float]] = None
category: Optional[Hashable] = instance.get_constraint_category(cname)
if category is not None:
categories.append(category)
assert isinstance(category, collections.Hashable), (
f"Constraint category must be hashable. "
f"Found {type(category).__name__} instead for cname={cname}.",
)
cf = instance.get_constraint_features(cname)
if cname in constr_features_dict:
cf = constr_features_dict[cname]
if isinstance(cf, np.ndarray):
cf = cf.tolist()
assert isinstance(cf, list), (
@@ -258,10 +267,8 @@ class FeaturesExtractor:
f"Constraint features must be a list of numbers. "
f"Found {type(f).__name__} instead for cname={cname}."
)
user_features.append(list(cf))
else:
user_features.append(None)
categories.append(None)
cf = list(cf)
user_features.append(cf)
if has_static_lazy:
lazy.append(instance.is_constraint_lazy(cname))
else:

View File

@@ -4,7 +4,7 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
from miplearn.features import Sample
@@ -96,11 +96,11 @@ class Instance(ABC):
"""
return {}
def get_constraint_features(self, cid: str) -> List[float]:
return [0.0]
def get_constraint_features(self) -> Dict[str, List[float]]:
return {}
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
return cid
def get_constraint_categories(self) -> Dict[str, Hashable]:
return {}
def has_static_lazy_constraints(self) -> bool:
return False

View File

@@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
return self.instance.get_variable_categories()
@overrides
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
def get_constraint_features(self) -> Dict[str, List[float]]:
assert self.instance is not None
return self.instance.get_constraint_features(cid)
return self.instance.get_constraint_features()
@overrides
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
def get_constraint_categories(self) -> Dict[str, Hashable]:
assert self.instance is not None
return self.instance.get_constraint_category(cid)
return self.instance.get_constraint_categories()
@overrides
def has_static_lazy_constraints(self) -> bool:

View File

@@ -6,7 +6,6 @@ from unittest.mock import Mock
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import MinProbabilityThreshold
@@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[0].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": "type-a",
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[0].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c1": [1.0, 2.0, 3.0],
"c2": [4.0, 5.0, 6.0],
"c3": [1.0, 2.0],
"c4": [3.0, 4.0],
}[cid]
}
)
instances[1].samples = [
Sample(
@@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0]
)
instances[1].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[1].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": None,
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[1].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c2": [7.0, 8.0, 9.0],
"c3": [5.0, 6.0],
"c4": [7.0, 8.0],
}[cid]
}
)
return instances

View File

@@ -83,7 +83,7 @@ def test_knapsack() -> None:
sa_rhs_up=[2.0],
senses=["="],
slacks=[0.0],
user_features=[[0.0]],
user_features=[None],
),
)
assert_equals(