mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Request constraint features/categories in bulk
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
|
|||||||
x: Dict[Hashable, List[List[float]]] = {}
|
x: Dict[Hashable, List[List[float]]] = {}
|
||||||
y: Dict[Hashable, List[List[bool]]] = {}
|
y: Dict[Hashable, List[List[bool]]] = {}
|
||||||
cids: Dict[Hashable, List[str]] = {}
|
cids: Dict[Hashable, List[str]] = {}
|
||||||
|
constr_categories_dict = instance.get_constraint_categories()
|
||||||
|
constr_features_dict = instance.get_constraint_features()
|
||||||
for cid in self.known_cids:
|
for cid in self.known_cids:
|
||||||
# Initialize categories
|
# Initialize categories
|
||||||
category = instance.get_constraint_category(cid)
|
if cid in constr_categories_dict:
|
||||||
|
category = constr_categories_dict[cid]
|
||||||
|
else:
|
||||||
|
category = cid
|
||||||
if category is None:
|
if category is None:
|
||||||
continue
|
continue
|
||||||
if category not in x:
|
if category not in x:
|
||||||
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
|
|||||||
assert sample.after_load is not None
|
assert sample.after_load is not None
|
||||||
assert sample.after_load.instance is not None
|
assert sample.after_load.instance is not None
|
||||||
features.extend(sample.after_load.instance.to_list())
|
features.extend(sample.after_load.instance.to_list())
|
||||||
features.extend(instance.get_constraint_features(cid))
|
if cid in constr_features_dict:
|
||||||
|
features.extend(constr_features_dict[cid])
|
||||||
for ci in features:
|
for ci in features:
|
||||||
assert isinstance(ci, float), (
|
assert isinstance(ci, float), (
|
||||||
f"Constraint features must be a list of floats. "
|
f"Constraint features must be a list of floats. "
|
||||||
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
|
|||||||
tn: Dict[Hashable, int] = {}
|
tn: Dict[Hashable, int] = {}
|
||||||
fp: Dict[Hashable, int] = {}
|
fp: Dict[Hashable, int] = {}
|
||||||
fn: Dict[Hashable, int] = {}
|
fn: Dict[Hashable, int] = {}
|
||||||
|
constr_categories_dict = instance.get_constraint_categories()
|
||||||
for cid in self.known_cids:
|
for cid in self.known_cids:
|
||||||
category = instance.get_constraint_category(cid)
|
if cid not in constr_categories_dict:
|
||||||
if category is None:
|
|
||||||
continue
|
continue
|
||||||
|
category = constr_categories_dict[cid]
|
||||||
if category not in tp.keys():
|
if category not in tp.keys():
|
||||||
tp[category] = 0
|
tp[category] = 0
|
||||||
tn[category] = 0
|
tn[category] = 0
|
||||||
|
|||||||
@@ -237,16 +237,25 @@ class FeaturesExtractor:
|
|||||||
user_features: List[Optional[List[float]]] = []
|
user_features: List[Optional[List[float]]] = []
|
||||||
categories: List[Optional[Hashable]] = []
|
categories: List[Optional[Hashable]] = []
|
||||||
lazy: List[bool] = []
|
lazy: List[bool] = []
|
||||||
|
constr_categories_dict = instance.get_constraint_categories()
|
||||||
|
constr_features_dict = instance.get_constraint_features()
|
||||||
|
|
||||||
for (cidx, cname) in enumerate(features.constraints.names):
|
for (cidx, cname) in enumerate(features.constraints.names):
|
||||||
|
category: Optional[Hashable] = cname
|
||||||
|
if cname in constr_categories_dict:
|
||||||
|
category = constr_categories_dict[cname]
|
||||||
|
if category is None:
|
||||||
|
user_features.append(None)
|
||||||
|
categories.append(None)
|
||||||
|
continue
|
||||||
|
assert isinstance(category, collections.Hashable), (
|
||||||
|
f"Constraint category must be hashable. "
|
||||||
|
f"Found {type(category).__name__} instead for cname={cname}.",
|
||||||
|
)
|
||||||
|
categories.append(category)
|
||||||
cf: Optional[List[float]] = None
|
cf: Optional[List[float]] = None
|
||||||
category: Optional[Hashable] = instance.get_constraint_category(cname)
|
if cname in constr_features_dict:
|
||||||
if category is not None:
|
cf = constr_features_dict[cname]
|
||||||
categories.append(category)
|
|
||||||
assert isinstance(category, collections.Hashable), (
|
|
||||||
f"Constraint category must be hashable. "
|
|
||||||
f"Found {type(category).__name__} instead for cname={cname}.",
|
|
||||||
)
|
|
||||||
cf = instance.get_constraint_features(cname)
|
|
||||||
if isinstance(cf, np.ndarray):
|
if isinstance(cf, np.ndarray):
|
||||||
cf = cf.tolist()
|
cf = cf.tolist()
|
||||||
assert isinstance(cf, list), (
|
assert isinstance(cf, list), (
|
||||||
@@ -258,10 +267,8 @@ class FeaturesExtractor:
|
|||||||
f"Constraint features must be a list of numbers. "
|
f"Constraint features must be a list of numbers. "
|
||||||
f"Found {type(f).__name__} instead for cname={cname}."
|
f"Found {type(f).__name__} instead for cname={cname}."
|
||||||
)
|
)
|
||||||
user_features.append(list(cf))
|
cf = list(cf)
|
||||||
else:
|
user_features.append(cf)
|
||||||
user_features.append(None)
|
|
||||||
categories.append(None)
|
|
||||||
if has_static_lazy:
|
if has_static_lazy:
|
||||||
lazy.append(instance.is_constraint_lazy(cname))
|
lazy.append(instance.is_constraint_lazy(cname))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Optional, Hashable, TYPE_CHECKING, Dict
|
from typing import Any, List, Hashable, TYPE_CHECKING, Dict
|
||||||
|
|
||||||
from miplearn.features import Sample
|
from miplearn.features import Sample
|
||||||
|
|
||||||
@@ -96,11 +96,11 @@ class Instance(ABC):
|
|||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def get_constraint_features(self, cid: str) -> List[float]:
|
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||||
return [0.0]
|
return {}
|
||||||
|
|
||||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||||
return cid
|
return {}
|
||||||
|
|
||||||
def has_static_lazy_constraints(self) -> bool:
|
def has_static_lazy_constraints(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -56,14 +56,14 @@ class PickleGzInstance(Instance):
|
|||||||
return self.instance.get_variable_categories()
|
return self.instance.get_variable_categories()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
|
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||||
assert self.instance is not None
|
assert self.instance is not None
|
||||||
return self.instance.get_constraint_features(cid)
|
return self.instance.get_constraint_features()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
|
def get_constraint_categories(self) -> Dict[str, Hashable]:
|
||||||
assert self.instance is not None
|
assert self.instance is not None
|
||||||
return self.instance.get_constraint_category(cid)
|
return self.instance.get_constraint_categories()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def has_static_lazy_constraints(self) -> bool:
|
def has_static_lazy_constraints(self) -> bool:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_array_equal
|
|
||||||
|
|
||||||
from miplearn.classifiers import Classifier
|
from miplearn.classifiers import Classifier
|
||||||
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
from miplearn.classifiers.threshold import MinProbabilityThreshold
|
||||||
@@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
|
|||||||
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
|
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
|
||||||
return_value=[5.0]
|
return_value=[5.0]
|
||||||
)
|
)
|
||||||
instances[0].get_constraint_category = Mock( # type: ignore
|
instances[0].get_constraint_categories = Mock( # type: ignore
|
||||||
side_effect=lambda cid: {
|
return_value={
|
||||||
"c1": "type-a",
|
"c1": "type-a",
|
||||||
"c2": "type-a",
|
"c2": "type-a",
|
||||||
"c3": "type-b",
|
"c3": "type-b",
|
||||||
"c4": "type-b",
|
"c4": "type-b",
|
||||||
}[cid]
|
}
|
||||||
)
|
)
|
||||||
instances[0].get_constraint_features = Mock( # type: ignore
|
instances[0].get_constraint_features = Mock( # type: ignore
|
||||||
side_effect=lambda cid: {
|
return_value={
|
||||||
"c1": [1.0, 2.0, 3.0],
|
"c1": [1.0, 2.0, 3.0],
|
||||||
"c2": [4.0, 5.0, 6.0],
|
"c2": [4.0, 5.0, 6.0],
|
||||||
"c3": [1.0, 2.0],
|
"c3": [1.0, 2.0],
|
||||||
"c4": [3.0, 4.0],
|
"c4": [3.0, 4.0],
|
||||||
}[cid]
|
}
|
||||||
)
|
)
|
||||||
instances[1].samples = [
|
instances[1].samples = [
|
||||||
Sample(
|
Sample(
|
||||||
@@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
|
|||||||
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
|
||||||
return_value=[8.0]
|
return_value=[8.0]
|
||||||
)
|
)
|
||||||
instances[1].get_constraint_category = Mock( # type: ignore
|
instances[1].get_constraint_categories = Mock( # type: ignore
|
||||||
side_effect=lambda cid: {
|
return_value={
|
||||||
"c1": None,
|
"c1": None,
|
||||||
"c2": "type-a",
|
"c2": "type-a",
|
||||||
"c3": "type-b",
|
"c3": "type-b",
|
||||||
"c4": "type-b",
|
"c4": "type-b",
|
||||||
}[cid]
|
}
|
||||||
)
|
)
|
||||||
instances[1].get_constraint_features = Mock( # type: ignore
|
instances[1].get_constraint_features = Mock( # type: ignore
|
||||||
side_effect=lambda cid: {
|
return_value={
|
||||||
"c2": [7.0, 8.0, 9.0],
|
"c2": [7.0, 8.0, 9.0],
|
||||||
"c3": [5.0, 6.0],
|
"c3": [5.0, 6.0],
|
||||||
"c4": [7.0, 8.0],
|
"c4": [7.0, 8.0],
|
||||||
}[cid]
|
}
|
||||||
)
|
)
|
||||||
return instances
|
return instances
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ def test_knapsack() -> None:
|
|||||||
sa_rhs_up=[2.0],
|
sa_rhs_up=[2.0],
|
||||||
senses=["="],
|
senses=["="],
|
||||||
slacks=[0.0],
|
slacks=[0.0],
|
||||||
user_features=[[0.0]],
|
user_features=[None],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert_equals(
|
assert_equals(
|
||||||
|
|||||||
Reference in New Issue
Block a user