Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 01:18:52 -06:00)

Commit: Replace Hashable by str
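
Note: this commit narrows the type annotations of category keys from typing.Hashable to str throughout the solver components (Component, DynamicConstraintsComponent, DynamicLazyConstraintsComponent, UserCutsComponent, ObjectiveValueComponent, PrimalSolutionComponent, StaticLazyConstraintsComponent). A minimal sketch of what callers of fit_xy pass after this change, assuming `component` is any Component subclass (names and values below are illustrative, not part of the commit):

import numpy as np
from typing import Dict

# Category keys are plain strings after this change.
x: Dict[str, np.ndarray] = {"Lower bound": np.array([[1.0, 2.0]])}  # illustrative features
y: Dict[str, np.ndarray] = {"Lower bound": np.array([[1.0]])}       # illustrative labels
# component.fit_xy(x, y)  # hypothetical call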

@@ -2,7 +2,7 @@
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
 
-from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable, Optional
+from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Optional
 
 import numpy as np
 from p_tqdm import p_umap
@@ -101,8 +101,8 @@ class Component:
 
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         """
         Given two dictionaries x and y, mapping the name of the category to matrices
@@ -152,7 +152,7 @@ class Component:
         self,
         instance: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return {}
 
     def sample_xy(

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
+from typing import Dict, List, Tuple, Optional, Any, Set
 
 import numpy as np
 from overrides import overrides
@@ -32,8 +32,8 @@ class DynamicConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.threshold_prototype: Threshold = threshold
         self.classifier_prototype: Classifier = classifier
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.known_cids: List[str] = []
         self.attr = attr
 
@@ -42,14 +42,14 @@ class DynamicConstraintsComponent(Component):
         instance: Optional[Instance],
         sample: Sample,
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[bool]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[bool]]],
+        Dict[str, List[str]],
     ]:
         assert instance is not None
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[bool]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[bool]]] = {}
+        cids: Dict[str, List[str]] = {}
         constr_categories_dict = instance.get_constraint_categories()
         constr_features_dict = instance.get_constraint_features()
         instance_features = sample.get("instance_features_user")
@@ -111,8 +111,8 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
-        pred: List[Hashable] = []
+    ) -> List[str]:
+        pred: List[str] = []
         if len(self.known_cids) == 0:
             logger.info("Classifiers not fitted. Skipping.")
             return pred
@@ -137,8 +137,8 @@ class DynamicConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for category in x.keys():
             self.classifiers[category] = self.classifier_prototype.clone()
@@ -153,14 +153,14 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         actual = sample.get(self.attr)
         assert actual is not None
         pred = set(self.sample_predict(instance, sample))
-        tp: Dict[Hashable, int] = {}
-        tn: Dict[Hashable, int] = {}
-        fp: Dict[Hashable, int] = {}
-        fn: Dict[Hashable, int] = {}
+        tp: Dict[str, int] = {}
+        tn: Dict[str, int] = {}
+        fp: Dict[str, int] = {}
+        fn: Dict[str, int] = {}
         constr_categories_dict = instance.get_constraint_categories()
         for cid in self.known_cids:
             if cid not in constr_categories_dict:

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, List, TYPE_CHECKING, Hashable, Tuple, Any, Optional, Set
+from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
 
 import numpy as np
 from overrides import overrides
@@ -41,11 +41,11 @@ class DynamicLazyConstraintsComponent(Component):
         self.classifiers = self.dynamic.classifiers
         self.thresholds = self.dynamic.thresholds
         self.known_cids = self.dynamic.known_cids
-        self.lazy_enforced: Set[Hashable] = set()
+        self.lazy_enforced: Set[str] = set()
 
     @staticmethod
     def enforce(
-        cids: List[Hashable],
+        cids: List[str],
         instance: Instance,
         model: Any,
         solver: "LearningSolver",
@@ -117,7 +117,7 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)
 
     @overrides
@@ -127,8 +127,8 @@ class DynamicLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)
 
@@ -137,5 +137,5 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
+from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
 
 import numpy as np
 from overrides import overrides
@@ -34,7 +34,7 @@ class UserCutsComponent(Component):
             threshold=threshold,
             attr="user_cuts_enforced",
         )
-        self.enforced: Set[Hashable] = set()
+        self.enforced: Set[str] = set()
         self.n_added_in_callback = 0
 
     @overrides
@@ -71,7 +71,7 @@ class UserCutsComponent(Component):
         for cid in cids:
             if cid in self.enforced:
                 continue
-            assert isinstance(cid, Hashable)
+            assert isinstance(cid, str)
             instance.enforce_user_cut(solver.internal_solver, model, cid)
             self.enforced.add(cid)
             self.n_added_in_callback += 1
@@ -110,7 +110,7 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)
 
     @overrides
@@ -120,8 +120,8 @@ class UserCutsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)
 
@@ -130,5 +130,5 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable, Optional
+from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Optional
 
 import numpy as np
 from overrides import overrides
@@ -53,8 +53,8 @@ class ObjectiveValueComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in ["Upper bound", "Lower bound"]:
             if c in y:
@@ -76,20 +76,20 @@ class ObjectiveValueComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         lp_instance_features = sample.get("lp_instance_features")
         if lp_instance_features is None:
             lp_instance_features = sample.get("instance_features_user")
         assert lp_instance_features is not None
 
         # Features
-        x: Dict[Hashable, List[List[float]]] = {
+        x: Dict[str, List[List[float]]] = {
             "Upper bound": [lp_instance_features],
             "Lower bound": [lp_instance_features],
         }
 
         # Labels
-        y: Dict[Hashable, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
         mip_lower_bound = sample.get("mip_lower_bound")
         mip_upper_bound = sample.get("mip_upper_bound")
         if mip_lower_bound is not None:
@@ -104,7 +104,7 @@ class ObjectiveValueComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
             err = np.round(abs(y_pred - y_actual), 8)
             return {
@@ -114,7 +114,7 @@ class ObjectiveValueComponent(Component):
                 "Relative error": err / y_actual,
             }
 
-        result: Dict[Hashable, Dict[str, float]] = {}
+        result: Dict[str, Dict[str, float]] = {}
         pred = self.sample_predict(sample)
         actual_ub = sample.get("mip_upper_bound")
         actual_lb = sample.get("mip_lower_bound")

@@ -3,15 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import (
-    Dict,
-    List,
-    Hashable,
-    Any,
-    TYPE_CHECKING,
-    Tuple,
-    Optional,
-)
+from typing import Dict, List, Any, TYPE_CHECKING, Tuple, Optional
 
 import numpy as np
 from overrides import overrides
@@ -55,8 +47,8 @@ class PrimalSolutionComponent(Component):
         assert isinstance(threshold, Threshold)
         assert mode in ["exact", "heuristic"]
         self.mode = mode
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.threshold_prototype = threshold
         self.classifier_prototype = classifier
 
@@ -128,7 +120,7 @@ class PrimalSolutionComponent(Component):
 
         # Convert y_pred into solution
         solution: Solution = {v: None for v in var_names}
-        category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
+        category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
         for (i, var_name) in enumerate(var_names):
             category = var_categories[i]
             if category not in category_offset:
@@ -194,7 +186,7 @@ class PrimalSolutionComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         mip_var_values = sample.get("mip_var_values")
         var_names = sample.get("var_names")
         assert mip_var_values is not None
@@ -221,13 +213,13 @@ class PrimalSolutionComponent(Component):
         pred_one_negative = vars_all - pred_one_positive
         pred_zero_negative = vars_all - pred_zero_positive
         return {
-            0: classifier_evaluation_dict(
+            "0": classifier_evaluation_dict(
                 tp=len(pred_zero_positive & vars_zero),
                 tn=len(pred_zero_negative & vars_one),
                 fp=len(pred_zero_positive & vars_one),
                 fn=len(pred_zero_negative & vars_zero),
             ),
-            1: classifier_evaluation_dict(
+            "1": classifier_evaluation_dict(
                 tp=len(pred_one_positive & vars_one),
                 tn=len(pred_one_negative & vars_zero),
                 fp=len(pred_one_positive & vars_zero),

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
+from typing import Dict, Tuple, List, Any, TYPE_CHECKING, Set, Optional
 
 import numpy as np
 from overrides import overrides
@@ -44,11 +44,11 @@ class StaticLazyConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.classifier_prototype: Classifier = classifier
         self.threshold_prototype: Threshold = threshold
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.pool: Constraints = Constraints()
         self.violation_tolerance: float = violation_tolerance
-        self.enforced_cids: Set[Hashable] = set()
+        self.enforced_cids: Set[str] = set()
         self.n_restored: int = 0
         self.n_iterations: int = 0
 
@@ -105,8 +105,8 @@ class StaticLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in y.keys():
             assert c in x
@@ -136,9 +136,9 @@ class StaticLazyConstraintsComponent(Component):
     ) -> None:
         self._check_and_add(solver)
 
-    def sample_predict(self, sample: Sample) -> List[Hashable]:
+    def sample_predict(self, sample: Sample) -> List[str]:
         x, y, cids = self._sample_xy_with_cids(sample)
-        enforced_cids: List[Hashable] = []
+        enforced_cids: List[str] = []
         for category in x.keys():
             if category not in self.classifiers:
                 continue
@@ -156,7 +156,7 @@ class StaticLazyConstraintsComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         x, y, __ = self._sample_xy_with_cids(sample)
         return x, y
 
@@ -197,13 +197,13 @@ class StaticLazyConstraintsComponent(Component):
     def _sample_xy_with_cids(
         self, sample: Sample
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[str]],
     ]:
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[float]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
+        cids: Dict[str, List[str]] = {}
         instance_features = sample.get("instance_features_user")
         constr_features = sample.get("lp_constr_features")
         constr_names = sample.get("constr_names")
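
Note on the PrimalSolutionComponent hunk above: the evaluation dictionary returned by sample_evaluate is now keyed by the strings "0" and "1" rather than the integers 0 and 1, so code that reads these results must switch to string keys. A minimal sketch with illustrative values only (the metric names and numbers are not from the commit):

# `ev` stands in for the dict returned by PrimalSolutionComponent.sample_evaluate (assumed shape).
ev = {"0": {"Precision": 1.0}, "1": {"Precision": 0.9}}
stats_for_zeros = ev["0"]  # was ev[0] before this commit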