Request constraint features/categories in bulk

master
Alinson S. Xavier 4 years ago
parent 8118ab4110
commit a5092cc2b9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
import numpy as np import numpy as np
from overrides import overrides from overrides import overrides
@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
x: Dict[Hashable, List[List[float]]] = {} x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {} y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {} cids: Dict[Hashable, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for cid in self.known_cids: for cid in self.known_cids:
# Initialize categories # Initialize categories
category = instance.get_constraint_category(cid) if cid in constr_categories_dict:
category = constr_categories_dict[cid]
else:
category = cid
if category is None: if category is None:
continue continue
if category not in x: if category not in x:
@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
assert sample.after_load is not None assert sample.after_load is not None
assert sample.after_load.instance is not None assert sample.after_load.instance is not None
features.extend(sample.after_load.instance.to_list()) features.extend(sample.after_load.instance.to_list())
features.extend(instance.get_constraint_features(cid)) if cid in constr_features_dict:
features.extend(constr_features_dict[cid])
for ci in features: for ci in features:
assert isinstance(ci, float), ( assert isinstance(ci, float), (
f"Constraint features must be a list of floats. " f"Constraint features must be a list of floats. "
@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
tn: Dict[Hashable, int] = {} tn: Dict[Hashable, int] = {}
fp: Dict[Hashable, int] = {} fp: Dict[Hashable, int] = {}
fn: Dict[Hashable, int] = {} fn: Dict[Hashable, int] = {}
constr_categories_dict = instance.get_constraint_categories()
for cid in self.known_cids: for cid in self.known_cids:
category = instance.get_constraint_category(cid) if cid not in constr_categories_dict:
if category is None:
continue continue
category = constr_categories_dict[cid]
if category not in tp.keys(): if category not in tp.keys():
tp[category] = 0 tp[category] = 0
tn[category] = 0 tn[category] = 0

@ -237,16 +237,25 @@ class FeaturesExtractor:
user_features: List[Optional[List[float]]] = [] user_features: List[Optional[List[float]]] = []
categories: List[Optional[Hashable]] = [] categories: List[Optional[Hashable]] = []
lazy: List[bool] = [] lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for (cidx, cname) in enumerate(features.constraints.names): for (cidx, cname) in enumerate(features.constraints.names):
cf: Optional[List[float]] = None category: Optional[Hashable] = cname
category: Optional[Hashable] = instance.get_constraint_category(cname) if cname in constr_categories_dict:
if category is not None: category = constr_categories_dict[cname]
categories.append(category) if category is None:
user_features.append(None)
categories.append(None)
continue
assert isinstance(category, collections.Hashable), ( assert isinstance(category, collections.Hashable), (
f"Constraint category must be hashable. " f"Constraint category must be hashable. "
f"Found {type(category).__name__} instead for cname={cname}.", f"Found {type(category).__name__} instead for cname={cname}.",
) )
cf = instance.get_constraint_features(cname) categories.append(category)
cf: Optional[List[float]] = None
if cname in constr_features_dict:
cf = constr_features_dict[cname]
if isinstance(cf, np.ndarray): if isinstance(cf, np.ndarray):
cf = cf.tolist() cf = cf.tolist()
assert isinstance(cf, list), ( assert isinstance(cf, list), (
@ -258,10 +267,8 @@ class FeaturesExtractor:
f"Constraint features must be a list of numbers. " f"Constraint features must be a list of numbers. "
f"Found {type(f).__name__} instead for cname={cname}." f"Found {type(f).__name__} instead for cname={cname}."
) )
user_features.append(list(cf)) cf = list(cf)
else: user_features.append(cf)
user_features.append(None)
categories.append(None)
if has_static_lazy: if has_static_lazy:
lazy.append(instance.is_constraint_lazy(cname)) lazy.append(instance.is_constraint_lazy(cname))
else: else:

@ -4,7 +4,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict from typing import Any, List, Hashable, TYPE_CHECKING, Dict
from miplearn.features import Sample from miplearn.features import Sample
@ -96,11 +96,11 @@ class Instance(ABC):
""" """
return {} return {}
def get_constraint_features(self, cid: str) -> List[float]: def get_constraint_features(self) -> Dict[str, List[float]]:
return [0.0] return {}
def get_constraint_category(self, cid: str) -> Optional[Hashable]: def get_constraint_categories(self) -> Dict[str, Hashable]:
return cid return {}
def has_static_lazy_constraints(self) -> bool: def has_static_lazy_constraints(self) -> bool:
return False return False

@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
return self.instance.get_variable_categories() return self.instance.get_variable_categories()
@overrides @overrides
def get_constraint_features(self, cid: str) -> Optional[List[float]]: def get_constraint_features(self) -> Dict[str, List[float]]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_constraint_features(cid) return self.instance.get_constraint_features()
@overrides @overrides
def get_constraint_category(self, cid: str) -> Optional[Hashable]: def get_constraint_categories(self) -> Dict[str, Hashable]:
assert self.instance is not None assert self.instance is not None
return self.instance.get_constraint_category(cid) return self.instance.get_constraint_categories()
@overrides @overrides
def has_static_lazy_constraints(self) -> bool: def has_static_lazy_constraints(self) -> bool:

@ -6,7 +6,6 @@ from unittest.mock import Mock
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal
from miplearn.classifiers import Classifier from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import MinProbabilityThreshold from miplearn.classifiers.threshold import MinProbabilityThreshold
@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0] return_value=[5.0]
) )
instances[0].get_constraint_category = Mock( # type: ignore instances[0].get_constraint_categories = Mock( # type: ignore
side_effect=lambda cid: { return_value={
"c1": "type-a", "c1": "type-a",
"c2": "type-a", "c2": "type-a",
"c3": "type-b", "c3": "type-b",
"c4": "type-b", "c4": "type-b",
}[cid] }
) )
instances[0].get_constraint_features = Mock( # type: ignore instances[0].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: { return_value={
"c1": [1.0, 2.0, 3.0], "c1": [1.0, 2.0, 3.0],
"c2": [4.0, 5.0, 6.0], "c2": [4.0, 5.0, 6.0],
"c3": [1.0, 2.0], "c3": [1.0, 2.0],
"c4": [3.0, 4.0], "c4": [3.0, 4.0],
}[cid] }
) )
instances[1].samples = [ instances[1].samples = [
Sample( Sample(
@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0] return_value=[8.0]
) )
instances[1].get_constraint_category = Mock( # type: ignore instances[1].get_constraint_categories = Mock( # type: ignore
side_effect=lambda cid: { return_value={
"c1": None, "c1": None,
"c2": "type-a", "c2": "type-a",
"c3": "type-b", "c3": "type-b",
"c4": "type-b", "c4": "type-b",
}[cid] }
) )
instances[1].get_constraint_features = Mock( # type: ignore instances[1].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: { return_value={
"c2": [7.0, 8.0, 9.0], "c2": [7.0, 8.0, 9.0],
"c3": [5.0, 6.0], "c3": [5.0, 6.0],
"c4": [7.0, 8.0], "c4": [7.0, 8.0],
}[cid] }
) )
return instances return instances

@ -83,7 +83,7 @@ def test_knapsack() -> None:
sa_rhs_up=[2.0], sa_rhs_up=[2.0],
senses=["="], senses=["="],
slacks=[0.0], slacks=[0.0],
user_features=[[0.0]], user_features=[None],
), ),
) )
assert_equals( assert_equals(

Loading…
Cancel
Save