Request constraint features/categories in bulk

master
Alinson S. Xavier 4 years ago
parent 8118ab4110
commit a5092cc2b9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
import numpy as np
from overrides import overrides
@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for cid in self.known_cids:
# Initialize categories
category = instance.get_constraint_category(cid)
if cid in constr_categories_dict:
category = constr_categories_dict[cid]
else:
category = cid
if category is None:
continue
if category not in x:
@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
assert sample.after_load is not None
assert sample.after_load.instance is not None
features.extend(sample.after_load.instance.to_list())
features.extend(instance.get_constraint_features(cid))
if cid in constr_features_dict:
features.extend(constr_features_dict[cid])
for ci in features:
assert isinstance(ci, float), (
f"Constraint features must be a list of floats. "
@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
tn: Dict[Hashable, int] = {}
fp: Dict[Hashable, int] = {}
fn: Dict[Hashable, int] = {}
constr_categories_dict = instance.get_constraint_categories()
for cid in self.known_cids:
category = instance.get_constraint_category(cid)
if category is None:
if cid not in constr_categories_dict:
continue
category = constr_categories_dict[cid]
if category not in tp.keys():
tp[category] = 0
tn[category] = 0

@ -237,16 +237,25 @@ class FeaturesExtractor:
user_features: List[Optional[List[float]]] = []
categories: List[Optional[Hashable]] = []
lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for (cidx, cname) in enumerate(features.constraints.names):
cf: Optional[List[float]] = None
category: Optional[Hashable] = instance.get_constraint_category(cname)
if category is not None:
categories.append(category)
category: Optional[Hashable] = cname
if cname in constr_categories_dict:
category = constr_categories_dict[cname]
if category is None:
user_features.append(None)
categories.append(None)
continue
assert isinstance(category, collections.Hashable), (
f"Constraint category must be hashable. "
f"Found {type(category).__name__} instead for cname={cname}.",
)
cf = instance.get_constraint_features(cname)
categories.append(category)
cf: Optional[List[float]] = None
if cname in constr_features_dict:
cf = constr_features_dict[cname]
if isinstance(cf, np.ndarray):
cf = cf.tolist()
assert isinstance(cf, list), (
@ -258,10 +267,8 @@ class FeaturesExtractor:
f"Constraint features must be a list of numbers. "
f"Found {type(f).__name__} instead for cname={cname}."
)
user_features.append(list(cf))
else:
user_features.append(None)
categories.append(None)
cf = list(cf)
user_features.append(cf)
if has_static_lazy:
lazy.append(instance.is_constraint_lazy(cname))
else:

@ -4,7 +4,7 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
from miplearn.features import Sample
@ -96,11 +96,11 @@ class Instance(ABC):
"""
return {}
def get_constraint_features(self, cid: str) -> List[float]:
return [0.0]
def get_constraint_features(self) -> Dict[str, List[float]]:
return {}
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
return cid
def get_constraint_categories(self) -> Dict[str, Hashable]:
return {}
def has_static_lazy_constraints(self) -> bool:
return False

@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
return self.instance.get_variable_categories()
@overrides
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
def get_constraint_features(self) -> Dict[str, List[float]]:
assert self.instance is not None
return self.instance.get_constraint_features(cid)
return self.instance.get_constraint_features()
@overrides
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
def get_constraint_categories(self) -> Dict[str, Hashable]:
assert self.instance is not None
return self.instance.get_constraint_category(cid)
return self.instance.get_constraint_categories()
@overrides
def has_static_lazy_constraints(self) -> bool:

@ -6,7 +6,6 @@ from unittest.mock import Mock
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import MinProbabilityThreshold
@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[0].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": "type-a",
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[0].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c1": [1.0, 2.0, 3.0],
"c2": [4.0, 5.0, 6.0],
"c3": [1.0, 2.0],
"c4": [3.0, 4.0],
}[cid]
}
)
instances[1].samples = [
Sample(
@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0]
)
instances[1].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[1].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": None,
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[1].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c2": [7.0, 8.0, 9.0],
"c3": [5.0, 6.0],
"c4": [7.0, 8.0],
}[cid]
}
)
return instances

@ -83,7 +83,7 @@ def test_knapsack() -> None:
sa_rhs_up=[2.0],
senses=["="],
slacks=[0.0],
user_features=[[0.0]],
user_features=[None],
),
)
assert_equals(

Loading…
Cancel
Save