mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Replace Hashable by str
@@ -2,7 +2,7 @@
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
 
-from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable, Optional
+from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Optional
 
 import numpy as np
 from p_tqdm import p_umap
@@ -101,8 +101,8 @@ class Component:
 
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         """
         Given two dictionaries x and y, mapping the name of the category to matrices
@@ -152,7 +152,7 @@ class Component:
         self,
         instance: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return {}
 
     def sample_xy(

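For reference, after this change the dictionaries passed to fit_xy are keyed by plain string category names rather than arbitrary hashable objects. A minimal sketch of such a call, assuming a single hypothetical category named "default" (the array contents are illustrative, not taken from the repository):

import numpy as np

# Hypothetical training data: one entry per string category (illustrative only).
x = {"default": np.array([[0.0, 1.0], [2.0, 3.0]])}        # feature rows
y = {"default": np.array([[True, False], [False, True]])}  # matching labels

# component.fit_xy(x, y)  # one classifier is fitted per string category
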
@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
+from typing import Dict, List, Tuple, Optional, Any, Set
 
 import numpy as np
 from overrides import overrides
@@ -32,8 +32,8 @@ class DynamicConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.threshold_prototype: Threshold = threshold
         self.classifier_prototype: Classifier = classifier
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.known_cids: List[str] = []
         self.attr = attr
 
@@ -42,14 +42,14 @@ class DynamicConstraintsComponent(Component):
         instance: Optional[Instance],
         sample: Sample,
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[bool]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[bool]]],
+        Dict[str, List[str]],
     ]:
         assert instance is not None
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[bool]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[bool]]] = {}
+        cids: Dict[str, List[str]] = {}
         constr_categories_dict = instance.get_constraint_categories()
         constr_features_dict = instance.get_constraint_features()
         instance_features = sample.get("instance_features_user")
@@ -111,8 +111,8 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
-        pred: List[Hashable] = []
+    ) -> List[str]:
+        pred: List[str] = []
         if len(self.known_cids) == 0:
             logger.info("Classifiers not fitted. Skipping.")
             return pred
@@ -137,8 +137,8 @@ class DynamicConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for category in x.keys():
             self.classifiers[category] = self.classifier_prototype.clone()
@@ -153,14 +153,14 @@ class DynamicConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         actual = sample.get(self.attr)
         assert actual is not None
         pred = set(self.sample_predict(instance, sample))
-        tp: Dict[Hashable, int] = {}
-        tn: Dict[Hashable, int] = {}
-        fp: Dict[Hashable, int] = {}
-        fn: Dict[Hashable, int] = {}
+        tp: Dict[str, int] = {}
+        tn: Dict[str, int] = {}
+        fp: Dict[str, int] = {}
+        fn: Dict[str, int] = {}
         constr_categories_dict = instance.get_constraint_categories()
         for cid in self.known_cids:
             if cid not in constr_categories_dict:

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, List, TYPE_CHECKING, Hashable, Tuple, Any, Optional, Set
+from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
 
 import numpy as np
 from overrides import overrides
@@ -41,11 +41,11 @@ class DynamicLazyConstraintsComponent(Component):
         self.classifiers = self.dynamic.classifiers
         self.thresholds = self.dynamic.thresholds
         self.known_cids = self.dynamic.known_cids
-        self.lazy_enforced: Set[Hashable] = set()
+        self.lazy_enforced: Set[str] = set()
 
     @staticmethod
     def enforce(
-        cids: List[Hashable],
+        cids: List[str],
         instance: Instance,
         model: Any,
         solver: "LearningSolver",
@@ -117,7 +117,7 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)
 
     @overrides
@@ -127,8 +127,8 @@ class DynamicLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)
 
@@ -137,5 +137,5 @@ class DynamicLazyConstraintsComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
+from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
 
 import numpy as np
 from overrides import overrides
@@ -34,7 +34,7 @@ class UserCutsComponent(Component):
             threshold=threshold,
             attr="user_cuts_enforced",
         )
-        self.enforced: Set[Hashable] = set()
+        self.enforced: Set[str] = set()
         self.n_added_in_callback = 0
 
     @overrides
@@ -71,7 +71,7 @@ class UserCutsComponent(Component):
         for cid in cids:
             if cid in self.enforced:
                 continue
-            assert isinstance(cid, Hashable)
+            assert isinstance(cid, str)
             instance.enforce_user_cut(solver.internal_solver, model, cid)
             self.enforced.add(cid)
             self.n_added_in_callback += 1
@@ -110,7 +110,7 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         return self.dynamic.sample_predict(instance, sample)
 
     @overrides
@@ -120,8 +120,8 @@ class UserCutsComponent(Component):
     @overrides
     def fit_xy(
        self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         self.dynamic.fit_xy(x, y)
 
@@ -130,5 +130,5 @@ class UserCutsComponent(Component):
         self,
         instance: "Instance",
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         return self.dynamic.sample_evaluate(instance, sample)

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable, Optional
+from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Optional
 
 import numpy as np
 from overrides import overrides
@@ -53,8 +53,8 @@ class ObjectiveValueComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in ["Upper bound", "Lower bound"]:
             if c in y:
@@ -76,20 +76,20 @@ class ObjectiveValueComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         lp_instance_features = sample.get("lp_instance_features")
         if lp_instance_features is None:
             lp_instance_features = sample.get("instance_features_user")
         assert lp_instance_features is not None
 
         # Features
-        x: Dict[Hashable, List[List[float]]] = {
+        x: Dict[str, List[List[float]]] = {
             "Upper bound": [lp_instance_features],
             "Lower bound": [lp_instance_features],
         }
 
         # Labels
-        y: Dict[Hashable, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
         mip_lower_bound = sample.get("mip_lower_bound")
         mip_upper_bound = sample.get("mip_upper_bound")
         if mip_lower_bound is not None:
@@ -104,7 +104,7 @@ class ObjectiveValueComponent(Component):
         self,
         instance: Instance,
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
             err = np.round(abs(y_pred - y_actual), 8)
             return {
@@ -114,7 +114,7 @@ class ObjectiveValueComponent(Component):
                 "Relative error": err / y_actual,
             }
 
-        result: Dict[Hashable, Dict[str, float]] = {}
+        result: Dict[str, Dict[str, float]] = {}
         pred = self.sample_predict(sample)
         actual_ub = sample.get("mip_upper_bound")
         actual_lb = sample.get("mip_lower_bound")

@@ -3,15 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import (
-    Dict,
-    List,
-    Hashable,
-    Any,
-    TYPE_CHECKING,
-    Tuple,
-    Optional,
-)
+from typing import Dict, List, Any, TYPE_CHECKING, Tuple, Optional
 
 import numpy as np
 from overrides import overrides
@@ -55,8 +47,8 @@ class PrimalSolutionComponent(Component):
         assert isinstance(threshold, Threshold)
         assert mode in ["exact", "heuristic"]
         self.mode = mode
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.threshold_prototype = threshold
         self.classifier_prototype = classifier
 
@@ -128,7 +120,7 @@ class PrimalSolutionComponent(Component):
 
         # Convert y_pred into solution
         solution: Solution = {v: None for v in var_names}
-        category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
+        category_offset: Dict[str, int] = {cat: 0 for cat in x.keys()}
         for (i, var_name) in enumerate(var_names):
             category = var_categories[i]
             if category not in category_offset:
@@ -194,7 +186,7 @@ class PrimalSolutionComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Dict[Hashable, Dict[str, float]]:
+    ) -> Dict[str, Dict[str, float]]:
         mip_var_values = sample.get("mip_var_values")
         var_names = sample.get("var_names")
         assert mip_var_values is not None
@@ -221,13 +213,13 @@ class PrimalSolutionComponent(Component):
         pred_one_negative = vars_all - pred_one_positive
         pred_zero_negative = vars_all - pred_zero_positive
         return {
-            0: classifier_evaluation_dict(
+            "0": classifier_evaluation_dict(
                 tp=len(pred_zero_positive & vars_zero),
                 tn=len(pred_zero_negative & vars_one),
                 fp=len(pred_zero_positive & vars_one),
                 fn=len(pred_zero_negative & vars_zero),
            ),
-            1: classifier_evaluation_dict(
+            "1": classifier_evaluation_dict(
                 tp=len(pred_one_positive & vars_one),
                 tn=len(pred_one_negative & vars_zero),
                 fp=len(pred_one_positive & vars_zero),
@@ -238,8 +230,8 @@ class PrimalSolutionComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for category in x.keys():
             clf = self.classifier_prototype.clone()

@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 import logging
-from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
+from typing import Dict, Tuple, List, Any, TYPE_CHECKING, Set, Optional
 
 import numpy as np
 from overrides import overrides
@@ -44,11 +44,11 @@ class StaticLazyConstraintsComponent(Component):
         assert isinstance(classifier, Classifier)
         self.classifier_prototype: Classifier = classifier
         self.threshold_prototype: Threshold = threshold
-        self.classifiers: Dict[Hashable, Classifier] = {}
-        self.thresholds: Dict[Hashable, Threshold] = {}
+        self.classifiers: Dict[str, Classifier] = {}
+        self.thresholds: Dict[str, Threshold] = {}
         self.pool: Constraints = Constraints()
         self.violation_tolerance: float = violation_tolerance
-        self.enforced_cids: Set[Hashable] = set()
+        self.enforced_cids: Set[str] = set()
         self.n_restored: int = 0
         self.n_iterations: int = 0
 
@@ -105,8 +105,8 @@ class StaticLazyConstraintsComponent(Component):
     @overrides
     def fit_xy(
         self,
-        x: Dict[Hashable, np.ndarray],
-        y: Dict[Hashable, np.ndarray],
+        x: Dict[str, np.ndarray],
+        y: Dict[str, np.ndarray],
     ) -> None:
         for c in y.keys():
             assert c in x
@@ -136,9 +136,9 @@ class StaticLazyConstraintsComponent(Component):
     ) -> None:
         self._check_and_add(solver)
 
-    def sample_predict(self, sample: Sample) -> List[Hashable]:
+    def sample_predict(self, sample: Sample) -> List[str]:
         x, y, cids = self._sample_xy_with_cids(sample)
-        enforced_cids: List[Hashable] = []
+        enforced_cids: List[str] = []
         for category in x.keys():
             if category not in self.classifiers:
                 continue
@@ -156,7 +156,7 @@ class StaticLazyConstraintsComponent(Component):
         self,
         _: Optional[Instance],
         sample: Sample,
-    ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
+    ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
         x, y, __ = self._sample_xy_with_cids(sample)
         return x, y
 
@@ -197,13 +197,13 @@ class StaticLazyConstraintsComponent(Component):
     def _sample_xy_with_cids(
         self, sample: Sample
     ) -> Tuple[
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[List[float]]],
-        Dict[Hashable, List[str]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[List[float]]],
+        Dict[str, List[str]],
     ]:
-        x: Dict[Hashable, List[List[float]]] = {}
-        y: Dict[Hashable, List[List[float]]] = {}
-        cids: Dict[Hashable, List[str]] = {}
+        x: Dict[str, List[List[float]]] = {}
+        y: Dict[str, List[List[float]]] = {}
+        cids: Dict[str, List[str]] = {}
         instance_features = sample.get("instance_features_user")
         constr_features = sample.get("lp_constr_features")
         constr_names = sample.get("constr_names")

@@ -5,7 +5,7 @@
 import collections
 import numbers
 from math import log, isfinite
-from typing import TYPE_CHECKING, Dict, Optional, List, Hashable, Any
+from typing import TYPE_CHECKING, Dict, Optional, List, Any
 
 import numpy as np
 
@@ -142,7 +142,7 @@ class FeaturesExtractor:
         instance: "Instance",
         sample: Sample,
     ) -> None:
-        categories: List[Optional[Hashable]] = []
+        categories: List[Optional[str]] = []
         user_features: List[Optional[List[float]]] = []
         var_features_dict = instance.get_variable_features()
         var_categories_dict = instance.get_variable_categories()
@@ -153,9 +153,9 @@ class FeaturesExtractor:
                 user_features.append(None)
                 categories.append(None)
                 continue
-            category: Hashable = var_categories_dict[var_name]
-            assert isinstance(category, collections.Hashable), (
-                f"Variable category must be be hashable. "
+            category: str = var_categories_dict[var_name]
+            assert isinstance(category, str), (
+                f"Variable category must be a string. "
                 f"Found {type(category).__name__} instead for var={var_name}."
             )
             categories.append(category)
@@ -187,7 +187,7 @@ class FeaturesExtractor:
     ) -> None:
         has_static_lazy = instance.has_static_lazy_constraints()
         user_features: List[Optional[List[float]]] = []
-        categories: List[Optional[Hashable]] = []
+        categories: List[Optional[str]] = []
         lazy: List[bool] = []
         constr_categories_dict = instance.get_constraint_categories()
         constr_features_dict = instance.get_constraint_features()
@@ -195,15 +195,15 @@ class FeaturesExtractor:
         assert constr_names is not None
 
         for (cidx, cname) in enumerate(constr_names):
-            category: Optional[Hashable] = cname
+            category: Optional[str] = cname
             if cname in constr_categories_dict:
                 category = constr_categories_dict[cname]
             if category is None:
                 user_features.append(None)
                 categories.append(None)
                 continue
-            assert isinstance(category, collections.Hashable), (
-                f"Constraint category must be hashable. "
+            assert isinstance(category, str), (
+                f"Constraint category must be a string. "
                 f"Found {type(category).__name__} instead for cname={cname}.",
             )
             categories.append(category)

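One practical effect of the tightened assertions in the feature extractor is that instances which still return non-string categories now fail fast with a clear message. A small, standalone illustration of the check above (the tuple category and variable name are made up):

# Older user code might still return a tuple as a category.
category = ("x", 0)
var_name = "x[0]"
assert isinstance(category, str), (
    f"Variable category must be a string. "
    f"Found {type(category).__name__} instead for var={var_name}."
)
# Raises: AssertionError: Variable category must be a string.
#         Found tuple instead for var=x[0]
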
@@ -4,7 +4,7 @@
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, List, Hashable, TYPE_CHECKING, Dict
+from typing import Any, List, TYPE_CHECKING, Dict
 
 from miplearn.features.sample import Sample
 
@@ -83,7 +83,7 @@ class Instance(ABC):
         """
         return {}
 
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         """
         Returns a dictionary mapping the name of each variable to its category.
 
@@ -91,7 +91,6 @@ class Instance(ABC):
         internal ML model to predict the values of both variables. If a variable is not
         listed in the dictionary, ML models will ignore the variable.
 
-        A category can be any hashable type, such as strings, numbers or tuples.
         By default, returns {}.
         """
         return {}

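To illustrate the narrower signature, a user-defined instance now maps each variable name to a string category rather than to an arbitrary hashable. A hedged sketch (the class name, variable names, and the "default" label are made up; the other Instance methods are omitted):

from typing import Dict

from miplearn.instance.base import Instance


class MyInstance(Instance):
    # to_model() and the remaining Instance methods are omitted in this sketch.

    def get_variable_categories(self) -> Dict[str, str]:
        # Variables sharing a category share one ML model; tuples such as
        # ("x", 1) are no longer valid category values after this change.
        return {"x[0]": "default", "x[1]": "default"}
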
@@ -99,7 +98,7 @@ class Instance(ABC):
     def get_constraint_features(self) -> Dict[str, List[float]]:
         return {}
 
-    def get_constraint_categories(self) -> Dict[str, Hashable]:
+    def get_constraint_categories(self) -> Dict[str, str]:
         return {}
 
     def has_static_lazy_constraints(self) -> bool:
@@ -115,7 +114,7 @@ class Instance(ABC):
         self,
         solver: "InternalSolver",
         model: Any,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         """
         Returns lazy constraint violations found for the current solution.
 
@@ -125,10 +124,10 @@ class Instance(ABC):
         resolve the problem. The process repeats until no further lazy constraint
         violations are found.
 
-        Each "violation" is simply a string, a tuple or any other hashable type which
-        allows the instance to identify unambiguously which lazy constraint should be
-        generated. In the Traveling Salesman Problem, for example, a subtour
-        violation could be a frozen set containing the cities in the subtour.
+        Each "violation" is simply a string which allows the instance to identify
+        unambiguously which lazy constraint should be generated. In the Traveling
+        Salesman Problem, for example, a subtour violation could be a string
+        containing the cities in the subtour.
 
         The current solution can be queried with `solver.get_solution()`. If the solver
         is configured to use lazy callbacks, this solution may be non-integer.

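Because violations are now plain strings, an instance that previously returned, say, a frozen set of cities for each subtour must encode that information in a string and parse it back when the constraint is enforced. A rough, self-contained sketch under that assumption (the class and the encoding scheme are hypothetical, not part of the library):

from typing import Any, List

# Illustrative only: how a TSP-style instance might encode subtour violations
# as strings after this change.
class SubtourEncodingExample:
    def find_violated_lazy_constraints(self, solver: Any, model: Any) -> List[str]:
        # Encode each subtour as a sorted, comma-separated list of city indices,
        # e.g. {3, 1, 7} -> "1,3,7".
        subtours = [{1, 3, 7}]  # stand-in for a real subtour search
        return [",".join(str(c) for c in sorted(tour)) for tour in subtours]

    def enforce_lazy_constraint(self, solver: Any, model: Any, violation: str) -> None:
        # Decode the string back into city indices before building the cut.
        cities = [int(c) for c in violation.split(",")]
        print("would add a subtour elimination constraint over", cities)
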
@@ -141,7 +140,7 @@ class Instance(ABC):
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         """
         Adds constraints to the model to ensure that the given violation is fixed.
@@ -167,14 +166,14 @@ class Instance(ABC):
     def has_user_cuts(self) -> bool:
         return False
 
-    def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
+    def find_violated_user_cuts(self, model: Any) -> List[str]:
         return []
 
     def enforce_user_cut(
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> Any:
         return None

@@ -6,7 +6,7 @@ import gc
 import gzip
 import os
 import pickle
-from typing import Optional, Any, List, Hashable, cast, IO, TYPE_CHECKING, Dict
+from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
 
 from overrides import overrides
 
@@ -52,7 +52,7 @@ class PickleGzInstance(Instance):
         return self.instance.get_variable_features()
 
     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         assert self.instance is not None
         return self.instance.get_variable_categories()
 
@@ -62,7 +62,7 @@ class PickleGzInstance(Instance):
         return self.instance.get_constraint_features()
 
     @overrides
-    def get_constraint_categories(self) -> Dict[str, Hashable]:
+    def get_constraint_categories(self) -> Dict[str, str]:
         assert self.instance is not None
         return self.instance.get_constraint_categories()
 
@@ -86,7 +86,7 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-    ) -> List[Hashable]:
+    ) -> List[str]:
         assert self.instance is not None
         return self.instance.find_violated_lazy_constraints(solver, model)
 
@@ -95,13 +95,13 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         assert self.instance is not None
         self.instance.enforce_lazy_constraint(solver, model, violation)
 
     @overrides
-    def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
+    def find_violated_user_cuts(self, model: Any) -> List[str]:
         assert self.instance is not None
         return self.instance.find_violated_user_cuts(model)
 
@@ -110,7 +110,7 @@ class PickleGzInstance(Instance):
         self,
         solver: "InternalSolver",
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         assert self.instance is not None
         self.instance.enforce_user_cut(solver, model, violation)

@@ -1,7 +1,8 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Dict, Optional, Hashable, Any
+
+from typing import List, Dict, Optional
 
 import numpy as np
 import pyomo.environ as pe
@@ -10,7 +11,6 @@ from scipy.stats import uniform, randint, rv_discrete
 from scipy.stats.distributions import rv_frozen
 
 from miplearn.instance.base import Instance
-from miplearn.types import VariableName, Category
 
 
 class ChallengeA:

@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Dict, Hashable
+from typing import List, Dict
 
 import networkx as nx
 import numpy as np
@@ -12,7 +12,6 @@ from scipy.stats import uniform, randint
 from scipy.stats.distributions import rv_frozen
 
 from miplearn.instance.base import Instance
-from miplearn.types import VariableName, Category
 
 
 class ChallengeA:
@@ -85,7 +84,7 @@ class MaxWeightStableSetInstance(Instance):
         return features
 
     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{v}]": "default" for v in self.nodes}
 

@@ -1,7 +1,7 @@
 # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
-from typing import List, Tuple, FrozenSet, Any, Optional, Hashable, Dict
+from typing import List, Tuple, FrozenSet, Any, Optional, Dict
 
 import networkx as nx
 import numpy as np
@@ -11,10 +11,9 @@ from scipy.spatial.distance import pdist, squareform
 from scipy.stats import uniform, randint
 from scipy.stats.distributions import rv_frozen
 
-from miplearn.instance.base import Instance
 from miplearn.solvers.learning import InternalSolver
 from miplearn.solvers.pyomo.base import BasePyomoSolver
-from miplearn.types import VariableName, Category
+from miplearn.instance.base import Instance
 
 
 class ChallengeA:
@@ -82,7 +81,7 @@ class TravelingSalesmanInstance(Instance):
         return model
 
     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{e}]": f"x[{e}]" for e in self.edges}
 
     @overrides

@@ -6,7 +6,7 @@ import re
 import sys
 from io import StringIO
 from random import randint
-from typing import List, Any, Dict, Optional, Hashable, Tuple, TYPE_CHECKING
+from typing import List, Any, Dict, Optional, Tuple, TYPE_CHECKING
 
 from overrides import overrides
 
@@ -672,7 +672,7 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
         self,
         solver: InternalSolver,
         model: Any,
-        violation: Hashable,
+        violation: str,
     ) -> None:
         x0 = model.getVarByName("x[0]")
         model.cbLazy(x0 <= 0)

@@ -6,7 +6,7 @@ import logging
 import re
 import sys
 from io import StringIO
-from typing import Any, List, Dict, Optional, Tuple, Hashable
+from typing import Any, List, Dict, Optional, Tuple
 
 import numpy as np
 import pyomo
@@ -639,5 +639,5 @@ class PyomoTestInstanceKnapsack(Instance):
         }
 
     @overrides
-    def get_variable_categories(self) -> Dict[str, Hashable]:
+    def get_variable_categories(self) -> Dict[str, str]:
         return {f"x[{i}]": "default" for i in range(len(self.weights))}

@@ -2,7 +2,7 @@
 # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
 # Released under the modified BSD license. See COPYING.md for more details.
 
-from typing import Optional, Dict, Callable, Any, Union, TYPE_CHECKING, Hashable
+from typing import Optional, Dict, Callable, Any, Union, TYPE_CHECKING
 
 from mypy_extensions import TypedDict
 
@@ -10,7 +10,7 @@ if TYPE_CHECKING:
     # noinspection PyUnresolvedReferences
     from miplearn.solvers.learning import InternalSolver
 
-Category = Hashable
+Category = str
 IterationCallback = Callable[[], bool]
 LazyCallback = Callable[[Any, Any], None]
 SolverParams = Dict[str, Any]

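With Category now an alias for str, user code that previously typed categories as tuples or numbers needs to pack that information into a string. A hypothetical migration sketch (the function and naming scheme are made up for illustration):

from miplearn.types import Category

def category_for(var_name: str, block: int) -> Category:
    # Before this change, a tuple such as (var_name, block) was an acceptable
    # category; the same information is now encoded in a single string.
    return f"{var_name}[block={block}]"

# category_for("x", 2) == "x[block=2]"
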