Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 09:28:51 -06:00)
Replace Hashable by str
@@ -33,6 +33,7 @@
 )
 ```
 - `LazyConstraintComponent` has been renamed to `DynamicLazyConstraintsComponent`.
+- Categories, lazy constraints and cutting plane identifiers must now be strings, instead of `Hashable`. This change was required for compatibility with the HDF5 data format.

 ### Removed

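
The changelog entry above describes the breaking change driving every hunk below: identifiers that previously only had to be hashable (tuples, frozensets, numbers) must now be plain strings. A minimal sketch of how user code might adapt; the class, items and category names are made up for illustration and are not part of the commit.

```python
from typing import Dict

class MyInstance:  # hypothetical user-defined instance class
    def __init__(self) -> None:
        self.items = [0, 1, 2]

    # Old style (no longer accepted): tuple categories
    # def get_variable_categories(self) -> Dict[str, tuple]:
    #     return {f"x[{i}]": ("knapsack", i) for i in self.items}

    # New style: encode the same information as a plain string
    def get_variable_categories(self) -> Dict[str, str]:
        return {f"x[{i}]": f"knapsack[{i}]" for i in self.items}
```
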
@@ -2,7 +2,7 @@
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.

-from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable, Optional
+from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Optional

 import numpy as np
 from p_tqdm import p_umap
@@ -101,8 +101,8 @@ class Component:

     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         """
         Given two dictionaries x and y, mapping the name of the category to matrices
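
The hunk above narrows `fit_xy` to accept `Dict[str, np.ndarray]`. A small sketch of the expected layout, assuming the "Upper bound" category name used elsewhere in this diff; the feature and label values are invented.

```python
from typing import Dict
import numpy as np

x: Dict[str, np.ndarray] = {
    "Upper bound": np.array([[0.0, 0.0], [1.0, 2.0]]),  # one row per training sample
}
y: Dict[str, np.ndarray] = {
    "Upper bound": np.array([[100.0], [200.0]]),  # matching labels, same row count
}
# component.fit_xy(x, y) then trains one ML model per category key.
```
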
@@ -152,7 +152,7 @@ class Component:
         self,
         instance: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return {}

     def sample_xy(

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
+from typing import Dict, List, Tuple, Optional, Any, Set

 import numpy as np
 from overrides import overrides
@@ -32,8 +32,8 @@ class DynamicConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.threshold_prototype: Threshold = threshold
         self.classifier_prototype: Classifier = classifier
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.known_cids: List[str] = []
         self.attr = attr

@@ -42,14 +42,14 @@ class DynamicConstraintsComponent(Component):
         instance: Optional[Instance],
         sample: Sample,
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[bool]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[bool]]],
+        Dict[str, List[str]],
     ]:
         assert instance is not None
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[bool]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[bool]]] = {}
+        cids: Dict[str, List[str]] = {}
         constr_categories_dict = instance.get_constraint_categories()
         constr_features_dict = instance.get_constraint_features()
         instance_features = sample.get("instance_features_user")
@@ -111,8 +111,8 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
-        pred: List[Hashable] = []
+    ) -> List[str]:
+        pred: List[str] = []
         if len(self.known_cids) == 0:
             logger.info("Classifiers not fitted. Skipping.")
             return pred
@@ -137,8 +137,8 @@ class DynamicConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for category in x.keys():
             self.classifiers[category] = self.classifier_prototype.clone()
@@ -153,14 +153,14 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         actual = sample.get(self.attr)
         assert actual is not None
         pred = set(self.sample_predict(instance, sample))
-        tp: Dict[Hashable, int] = {}
-        tn: Dict[Hashable, int] = {}
-        fp: Dict[Hashable, int] = {}
-        fn: Dict[Hashable, int] = {}
+        tp: Dict[str, int] = {}
+        tn: Dict[str, int] = {}
+        fp: Dict[str, int] = {}
+        fn: Dict[str, int] = {}
         constr_categories_dict = instance.get_constraint_categories()
         for cid in self.known_cids:
             if cid not in constr_categories_dict:

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import Dict, List, TYPE_CHECKING, Hashable, Tuple, Any, Optional, Set
+from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set

 import numpy as np
 from overrides import overrides
@@ -41,11 +41,11 @@ class DynamicLazyConstraintsComponent(Component):
         self.classifiers = self.dynamic.classifiers
         self.thresholds = self.dynamic.thresholds
         self.known_cids = self.dynamic.known_cids
-        self.lazy_enforced: Set[Hashable] = set()
+        self.lazy_enforced: Set[str] = set()

     @staticmethod
     def enforce(
-        cids: List[Hashable],
+        cids: List[str],
         instance: Instance,
         model: Any,
         solver: "LearningSolver",
@@ -117,7 +117,7 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)

     @overrides
@@ -127,8 +127,8 @@ class DynamicLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)

@@ -137,5 +137,5 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
+from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional

 import numpy as np
 from overrides import overrides
@@ -34,7 +34,7 @@ class UserCutsComponent(Component):
             threshold=threshold,
             attr="user_cuts_enforced",
         )
-        self.enforced: Set[Hashable] = set()
+        self.enforced: Set[str] = set()
         self.n_added_in_callback = 0

     @overrides
@@ -71,7 +71,7 @@ class UserCutsComponent(Component):
         for cid in cids:
             if cid in self.enforced:
                 continue
-            assert isinstance(cid, Hashable)
+            assert isinstance(cid, str)
             instance.enforce_user_cut(solver.internal_solver, model, cid)
             self.enforced.add(cid)
             self.n_added_in_callback += 1
@@ -110,7 +110,7 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)

     @overrides
@@ -120,8 +120,8 @@ class UserCutsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)

@@ -130,5 +130,5 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable, Optional
+from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Optional

 import numpy as np
 from overrides import overrides
@@ -53,8 +53,8 @@ class ObjectiveValueComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in ["Upper bound", "Lower bound"]:
             if c in y:
@@ -76,20 +76,20 @@ class ObjectiveValueComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         lp_instance_features = sample.get("lp_instance_features")
         if lp_instance_features is None:
             lp_instance_features = sample.get("instance_features_user")
         assert lp_instance_features is not None

         # Features
-        x: Dict[Hashable, List[List[float]]] = {
+        x: Dict[str, List[List[float]]] = {
             "Upper bound": [lp_instance_features],
             "Lower bound": [lp_instance_features],
         }

         # Labels
-        y: Dict[Hashable, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
         mip_lower_bound = sample.get("mip_lower_bound")
         mip_upper_bound = sample.get("mip_upper_bound")
         if mip_lower_bound is not None:
@@ -104,7 +104,7 @@ class ObjectiveValueComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
             err = np.round(abs(y_pred - y_actual), 8)
             return {
@@ -114,7 +114,7 @@ class ObjectiveValueComponent(Component):
                 "Relative error": err / y_actual,
             }

-        result: Dict[Hashable, Dict[str, float]] = {}
+        result: Dict[str, Dict[str, float]] = {}
         pred = self.sample_predict(sample)
         actual_ub = sample.get("mip_upper_bound")
         actual_lb = sample.get("mip_lower_bound")

@@ -3,15 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import (
-    Dict,
-    List,
-    Hashable,
-    Any,
-    TYPE_CHECKING,
-    Tuple,
-    Optional,
-)
+from typing import Dict, List, Any, TYPE_CHECKING, Tuple, Optional

 import numpy as np
 from overrides import overrides
@@ -55,8 +47,8 @@ class PrimalSolutionComponent(Component):
         assert isinstance(threshold, Threshold)
         assert mode in ["exact", "heuristic"]
         self.mode = mode
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.threshold_prototype = threshold
         self.classifier_prototype = classifier

@@ -128,7 +120,7 @@ class PrimalSolutionComponent(Component):

         # Convert y_pred into solution
         solution: Solution = {v: None for v in var_names}
-        category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
+        category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
         for (i, var_name) in enumerate(var_names):
             category = var_categories[i]
             if category not in category_offset:
@@ -194,7 +186,7 @@ class PrimalSolutionComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         mip_var_values = sample.get("mip_var_values")
         var_names = sample.get("var_names")
         assert mip_var_values is not None
@@ -221,13 +213,13 @@ class PrimalSolutionComponent(Component):
         pred_one_negative = vars_all - pred_one_positive
         pred_zero_negative = vars_all - pred_zero_positive
         return {
-            0: classifier_evaluation_dict(
+            "0": classifier_evaluation_dict(
                 tp=len(pred_zero_positive & vars_zero),
                 tn=len(pred_zero_negative & vars_one),
                 fp=len(pred_zero_positive & vars_one),
                 fn=len(pred_zero_negative & vars_zero),
             ),
-            1: classifier_evaluation_dict(
+            "1": classifier_evaluation_dict(
                 tp=len(pred_one_positive & vars_one),
                 tn=len(pred_one_negative & vars_zero),
                 fp=len(pred_one_positive & vars_zero),
@@ -238,8 +230,8 @@ class PrimalSolutionComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for category in x.keys():
             clf = self.classifier_prototype.clone()

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
+from typing import Dict, Tuple, List, Any, TYPE_CHECKING, Set, Optional

 import numpy as np
 from overrides import overrides
@@ -44,11 +44,11 @@ class StaticLazyConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.classifier_prototype: Classifier = classifier
         self.threshold_prototype: Threshold = threshold
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.pool: Constraints = Constraints()
         self.violation_tolerance: float = violation_tolerance
-        self.enforced_cids: Set[Hashable] = set()
+        self.enforced_cids: Set[str] = set()
         self.n_restored: int = 0
         self.n_iterations: int = 0

@@ -105,8 +105,8 @@ class StaticLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in y.keys():
             assert c in x
@@ -136,9 +136,9 @@ class StaticLazyConstraintsComponent(Component):
     ) -> None:
         self._check_and_add(solver)

-    def sample_predict(self, sample: Sample) -> List[Hashable]:
+    def sample_predict(self, sample: Sample) -> List[str]:
         x, y, cids = self._sample_xy_with_cids(sample)
-        enforced_cids: List[Hashable] = []
+        enforced_cids: List[str] = []
         for category in x.keys():
             if category not in self.classifiers:
                 continue
@@ -156,7 +156,7 @@ class StaticLazyConstraintsComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         x, y, __ = self._sample_xy_with_cids(sample)
         return x, y

@@ -197,13 +197,13 @@ class StaticLazyConstraintsComponent(Component):
     def _sample_xy_with_cids(
         self, sample: Sample
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[str]],
     ]:
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[float]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
+        cids: Dict[str, List[str]] = {}
         instance_features = sample.get("instance_features_user")
         constr_features = sample.get("lp_constr_features")
         constr_names = sample.get("constr_names")

@@ -5,7 +5,7 @@
 import collections
 import numbers
 from math import log, isfinite
-from typing import TYPE_CHECKING, Dict, Optional, List, Hashable, Any
+from typing import TYPE_CHECKING, Dict, Optional, List, Any

 import numpy as np

@@ -142,7 +142,7 @@ class FeaturesExtractor:
         instance: "Instance",
         sample: Sample,
     ) -> None:
-        categories: List[Optional[Hashable]] = []
+        categories: List[Optional[str]] = []
         user_features: List[Optional[List[float]]] = []
         var_features_dict = instance.get_variable_features()
         var_categories_dict = instance.get_variable_categories()
@@ -153,9 +153,9 @@ class FeaturesExtractor:
                 user_features.append(None)
                 categories.append(None)
                 continue
-            category: Hashable = var_categories_dict[var_name]
-            assert isinstance(category, collections.Hashable), (
-                f"Variable category must be be hashable. "
+            category: str = var_categories_dict[var_name]
+            assert isinstance(category, str), (
+                f"Variable category must be a string. "
                 f"Found {type(category).__name__} instead for var={var_name}."
             )
             categories.append(category)
@@ -187,7 +187,7 @@ class FeaturesExtractor:
     ) -> None:
         has_static_lazy = instance.has_static_lazy_constraints()
         user_features: List[Optional[List[float]]] = []
-        categories: List[Optional[Hashable]] = []
+        categories: List[Optional[str]] = []
         lazy: List[bool] = []
         constr_categories_dict = instance.get_constraint_categories()
         constr_features_dict = instance.get_constraint_features()
@@ -195,15 +195,15 @@ class FeaturesExtractor:
         assert constr_names is not None

         for (cidx, cname) in enumerate(constr_names):
-            category: Optional[Hashable] = cname
+            category: Optional[str] = cname
             if cname in constr_categories_dict:
                 category = constr_categories_dict[cname]
             if category is None:
                 user_features.append(None)
                 categories.append(None)
                 continue
-            assert isinstance(category, collections.Hashable), (
-                f"Constraint category must be hashable. "
+            assert isinstance(category, str), (
+                f"Constraint category must be a string. "
                 f"Found {type(category).__name__} instead for cname={cname}.",
             )
             categories.append(category)

@@ -4,7 +4,7 @@

 import logging
 from abc import ABC, abstractmethod
-from typing import Any, List, Hashable, TYPE_CHECKING, Dict
+from typing import Any, List, TYPE_CHECKING, Dict

 from miplearn.features.sample import Sample

@@ -83,7 +83,7 @@ class Instance(ABC):
         """
         return {}

-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         """
         Returns a dictionary mapping the name of each variable to its category.

@@ -91,7 +91,6 @@ class Instance(ABC):
         internal ML model to predict the values of both variables. If a variable is not
         listed in the dictionary, ML models will ignore the variable.

-        A category can be any hashable type, such as strings, numbers or tuples.
         By default, returns {}.
         """
         return {}
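
A minimal sketch of the contract described in the docstring above, with made-up variable and category names: variables that share the same category string are handled by one internal ML model, and variables left out of the dictionary are ignored by the ML components.

```python
from typing import Dict

class ExampleInstance:  # stand-in for a user-defined Instance subclass
    def get_variable_categories(self) -> Dict[str, str]:
        return {
            "x[0]": "production",   # x[0] and x[1] share one ML model
            "x[1]": "production",
            "y[0]": "investment",   # y[0] gets its own model
            # "z[0]" is deliberately omitted: ML models will ignore it
        }
```
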
@@ -99,7 +98,7 @@
     def get_constraint_features(self) -> Dict[str, List[float]]:
         return {}

-    def get_constraint_categories(self) -> Dict[str, Hashable]:
+    def get_constraint_categories(self) -> Dict[str, str]:
         return {}

     def has_static_lazy_constraints(self) -> bool:
@@ -115,7 +114,7 @@
         self,
         solver: "InternalSolver",
         model: Any,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         """
         Returns lazy constraint violations found for the current solution.

@@ -125,10 +124,10 @@
         resolve the problem. The process repeats until no further lazy constraint
         violations are found.

-        Each "violation" is simply a string, a tuple or any other hashable type which
-        allows the instance to identify unambiguously which lazy constraint should be
-        generated. In the Traveling Salesman Problem, for example, a subtour
-        violation could be a frozen set containing the cities in the subtour.
+        Each "violation" is simply a string which allows the instance to identify
+        unambiguously which lazy constraint should be generated. In the Traveling
+        Salesman Problem, for example, a subtour violation could be a string
+        containing the cities in the subtour.

         The current solution can be queried with `solver.get_solution()`. If the solver
         is configured to use lazy callbacks, this solution may be non-integer.
@@ -141,7 +140,7 @@
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         """
         Adds constraints to the model to ensure that the given violation is fixed.
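
A hedged sketch of the violation-as-string protocol described in the docstrings above. The comma-separated encoding is only an assumption for illustration; the API merely requires that each violation be a string the instance can decode unambiguously.

```python
from typing import Any, List

class SubtourExample:  # stand-in for an Instance subclass; solver details omitted
    def find_violated_lazy_constraints(self, solver: Any, model: Any) -> List[str]:
        # Placeholder: in real code the subtours would be computed from
        # solver.get_solution(), as the docstring above explains.
        subtours = [[0, 2, 5]]
        return [",".join(str(city) for city in tour) for tour in subtours]

    def enforce_lazy_constraint(self, solver: Any, model: Any, violation: str) -> None:
        cities = [int(c) for c in violation.split(",")]
        # ...add the corresponding subtour-elimination constraint to `model`...
```
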
@@ -167,14 +166,14 @@
     def has_user_cuts(self) -> bool:
         return False

-    def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
+    def find_violated_user_cuts(self, model: Any) -> List[str]:
         return []

     def enforce_user_cut(
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> Any:
         return None

@@ -6,7 +6,7 @@ import gc
 import gzip
 import os
 import pickle
-from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
+from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict

 from overrides import overrides

@@ -52,7 +52,7 @@ class PickleGzInstance(Instance):
         return self.instance.get_variable_features()

     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         assert self.instance is not None
         return self.instance.get_variable_categories()

@@ -62,7 +62,7 @@ class PickleGzInstance(Instance):
         return self.instance.get_constraint_features()

     @overrides
-    def get_constraint_categories(self) -> Dict[str, Hashable]:
+    def get_constraint_categories(self) -> Dict[str, str]:
         assert self.instance is not None
         return self.instance.get_constraint_categories()

@@ -86,7 +86,7 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         assert self.instance is not None
         return self.instance.find_violated_lazy_constraints(solver, model)

@@ -95,13 +95,13 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         assert self.instance is not None
         self.instance.enforce_lazy_constraint(solver, model, violation)

     @overrides
-    def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
+    def find_violated_user_cuts(self, model: Any) -> List[str]:
         assert self.instance is not None
         return self.instance.find_violated_user_cuts(model)

@@ -110,7 +110,7 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         assert self.instance is not None
         self.instance.enforce_user_cut(solver, model, violation)

@@ -1,7 +1,8 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Dict, Optional, Hashable, Any
+
+from typing import List, Dict, Optional

 import numpy as np
 import pyomo.environ as pe
@@ -10,7 +11,6 @@ from scipy.stats import uniform, randint, rv_discrete
 from scipy.stats.distributions import rv_frozen

 from miplearn.instance.base import Instance
-from miplearn.types import VariableName, Category


 class ChallengeA:

@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Dict, Hashable
+from typing import List, Dict

 import networkx as nx
 import numpy as np
@@ -12,7 +12,6 @@ from scipy.stats import uniform, randint
 from scipy.stats.distributions import rv_frozen

 from miplearn.instance.base import Instance
-from miplearn.types import VariableName, Category


 class ChallengeA:
@@ -85,7 +84,7 @@ class MaxWeightStableSetInstance(Instance):
         return features

     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{v}]": "default" for v in self.nodes}


@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Tuple, FrozenSet, Any, Optional, Hashable, Dict
+from typing import List, Tuple, FrozenSet, Any, Optional, Dict

 import networkx as nx
 import numpy as np
@@ -11,10 +11,9 @@ from scipy.spatial.distance import pdist, squareform
 from scipy.stats import uniform, randint
 from scipy.stats.distributions import rv_frozen

+from miplearn.instance.base import Instance
 from miplearn.solvers.learning import InternalSolver
 from miplearn.solvers.pyomo.base import BasePyomoSolver
-from miplearn.instance.base import Instance
-from miplearn.types import VariableName, Category


 class ChallengeA:
@@ -82,7 +81,7 @@ class TravelingSalesmanInstance(Instance):
         return model

     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{e}]": f"x[{e}]" for e in self.edges}

     @overrides

@@ -6,7 +6,7 @@ import re
 import sys
 from io import StringIO
 from random import randint
-from typing import List, Any, Dict, Optional, Hashable, Tuple, TYPE_CHECKING
+from typing import List, Any, Dict, Optional, Tuple, TYPE_CHECKING

 from overrides import overrides

@@ -672,7 +672,7 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
         self,
         solver: InternalSolver,
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         x0 = model.getVarByName("x[0]")
         model.cbLazy(x0 <= 0)

@@ -6,7 +6,7 @@ import logging
 import re
 import sys
 from io import StringIO
-from typing import Any, List, Dict, Optional, Tuple, Hashable
+from typing import Any, List, Dict, Optional, Tuple

 import numpy as np
 import pyomo
@@ -639,5 +639,5 @@ class PyomoTestInstanceKnapsack(Instance):
         }

     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{i}]": "default" for i in range(len(self.weights))}

@@ -2,7 +2,7 @@
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.

-from typing import Optional, Dict, Callable, Any, Union, TYPE_CHECKING, Hashable
+from typing import Optional, Dict, Callable, Any, Union, TYPE_CHECKING

 from mypy_extensions import TypedDict

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     # noinspection PyUnresolvedReferences
     from miplearn.solvers.learning import InternalSolver

-Category = Hashable
+Category = str
 IterationCallback = Callable[[], bool]
 LazyCallback = Callable[[Any, Any], None]
 SolverParams = Dict[str, Any]

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.

 import logging
-from typing import Any, FrozenSet, Hashable, List
+from typing import Any, FrozenSet, List

 import gurobipy as gp
 import networkx as nx
@@ -40,13 +40,13 @@ class GurobiStableSetProblem(Instance):
         return True

     @overrides
-    def find_violated_user_cuts(self, model: Any) -> List[FrozenSet]:
+    def find_violated_user_cuts(self, model: Any) -> List[str]:
         assert isinstance(model, gp.Model)
         vals = model.cbGetNodeRel(model.getVars())
         violations = []
         for clique in nx.find_cliques(self.graph):
             if sum(vals[i] for i in clique) > 1:
-                violations += [frozenset(clique)]
+                violations.append(",".join([str(i) for i in clique]))
         return violations

     @overrides
@@ -54,11 +54,11 @@ class GurobiStableSetProblem(Instance):
         self,
         solver: InternalSolver,
         model: Any,
-        cid: Hashable,
+        cid: str,
     ) -> Any:
-        assert isinstance(cid, FrozenSet)
+        clique = [int(i) for i in cid.split(",")]
         x = model.getVars()
-        model.addConstr(gp.quicksum([x[i] for i in cid]) <= 1)
+        model.addConstr(gp.quicksum([x[i] for i in clique]) <= 1)


 @pytest.fixture
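
The two hunks above replace frozenset cliques with comma-joined strings and parse them back before building the cut. A standalone sketch of that round trip, with made-up clique indices; the final `addConstr` call is shown only as a comment because it needs a live Gurobi model.

```python
clique = [3, 7, 11]

cid = ",".join(str(i) for i in clique)      # encode: "3,7,11"
decoded = [int(i) for i in cid.split(",")]  # decode back to [3, 7, 11]
assert decoded == clique

# The decoded indices are then used to build the cut, e.g.:
# model.addConstr(gp.quicksum([x[i] for i in decoded]) <= 1)
```
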
@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import Hashable, Dict
+from typing import Dict
 from unittest.mock import Mock

 import numpy as np
@@ -44,11 +44,11 @@ def test_sample_xy(sample: Sample) -> None:


 def test_fit_xy() -> None:
-    x: Dict[Hashable, np.ndarray] = {
+    x: Dict[str, np.ndarray] = {
         "Lower bound": np.array([[0.0, 0.0], [1.0, 2.0]]),
         "Upper bound": np.array([[0.0, 0.0], [1.0, 2.0]]),
     }
-    y: Dict[Hashable, np.ndarray] = {
+    y: Dict[str, np.ndarray] = {
         "Lower bound": np.array([[100.0]]),
         "Upper bound": np.array([[200.0]]),
     }
@@ -121,8 +121,8 @@ def test_evaluate(sample: Sample) -> None:
     assert_equals(
         ev,
         {
-            0: classifier_evaluation_dict(tp=0, fp=1, tn=1, fn=2),
-            1: classifier_evaluation_dict(tp=1, fp=1, tn=1, fn=1),
+            "0": classifier_evaluation_dict(tp=0, fp=1, tn=1, fn=2),
+            "1": classifier_evaluation_dict(tp=1, fp=1, tn=1, fn=1),
         },
     )

@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import Dict, cast, Hashable
+from typing import Dict, cast
 from unittest.mock import Mock, call

 import numpy as np
@@ -175,14 +175,14 @@ def test_sample_predict(sample: Sample) -> None:

 def test_fit_xy() -> None:
     x = cast(
-        Dict[Hashable, np.ndarray],
+        Dict[str, np.ndarray],
         {
             "type-a": np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]),
             "type-b": np.array([[1.0, 4.0, 0.0]]),
         },
     )
     y = cast(
-        Dict[Hashable, np.ndarray],
+        Dict[str, np.ndarray],
         {
             "type-a": np.array([[False, True], [False, True], [True, False]]),
             "type-b": np.array([[False, True]]),