|
|
@ -1,7 +1,7 @@
|
|
|
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
|
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
|
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
|
|
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
|
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
|
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
from typing import Dict, List, Tuple, Optional, Any, Set
|
|
|
|
from typing import Dict, List, Tuple, Optional, Any, Set
|
|
|
|
|
|
|
|
|
|
|
@ -36,7 +36,7 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
self.classifier_prototype: Classifier = classifier
|
|
|
|
self.classifier_prototype: Classifier = classifier
|
|
|
|
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
|
|
|
|
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
|
|
|
|
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
|
|
|
|
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
|
|
|
|
self.known_cids: List[ConstraintName] = []
|
|
|
|
self.known_violations: Dict[ConstraintName, Any] = {}
|
|
|
|
self.attr = attr
|
|
|
|
self.attr = attr
|
|
|
|
|
|
|
|
|
|
|
|
def sample_xy_with_cids(
|
|
|
|
def sample_xy_with_cids(
|
|
|
@ -48,18 +48,19 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
Dict[ConstraintCategory, List[List[bool]]],
|
|
|
|
Dict[ConstraintCategory, List[List[bool]]],
|
|
|
|
Dict[ConstraintCategory, List[ConstraintName]],
|
|
|
|
Dict[ConstraintCategory, List[ConstraintName]],
|
|
|
|
]:
|
|
|
|
]:
|
|
|
|
if len(self.known_cids) == 0:
|
|
|
|
if len(self.known_violations) == 0:
|
|
|
|
return {}, {}, {}
|
|
|
|
return {}, {}, {}
|
|
|
|
assert instance is not None
|
|
|
|
assert instance is not None
|
|
|
|
x: Dict[ConstraintCategory, List[List[float]]] = {}
|
|
|
|
x: Dict[ConstraintCategory, List[List[float]]] = {}
|
|
|
|
y: Dict[ConstraintCategory, List[List[bool]]] = {}
|
|
|
|
y: Dict[ConstraintCategory, List[List[bool]]] = {}
|
|
|
|
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
|
|
|
|
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
|
|
|
|
known_cids = np.array(self.known_cids, dtype="S")
|
|
|
|
known_cids = np.array(sorted(list(self.known_violations.keys())), dtype="S")
|
|
|
|
|
|
|
|
|
|
|
|
enforced_cids = None
|
|
|
|
enforced_cids = None
|
|
|
|
enforced_cids_np = sample.get_array(self.attr)
|
|
|
|
enforced_encoded = sample.get_scalar(self.attr)
|
|
|
|
if enforced_cids_np is not None:
|
|
|
|
if enforced_encoded is not None:
|
|
|
|
enforced_cids = list(enforced_cids_np)
|
|
|
|
enforced = self.decode(enforced_encoded)
|
|
|
|
|
|
|
|
enforced_cids = list(enforced.keys())
|
|
|
|
|
|
|
|
|
|
|
|
# Get user-provided constraint features
|
|
|
|
# Get user-provided constraint features
|
|
|
|
(
|
|
|
|
(
|
|
|
@ -100,11 +101,10 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
@overrides
|
|
|
|
@overrides
|
|
|
|
def pre_fit(self, pre: List[Any]) -> None:
|
|
|
|
def pre_fit(self, pre: List[Any]) -> None:
|
|
|
|
assert pre is not None
|
|
|
|
assert pre is not None
|
|
|
|
known_cids: Set = set()
|
|
|
|
self.known_violations.clear()
|
|
|
|
for cids in pre:
|
|
|
|
for violations in pre:
|
|
|
|
known_cids |= set(list(cids))
|
|
|
|
for (vname, vdata) in violations.items():
|
|
|
|
self.known_cids.clear()
|
|
|
|
self.known_violations[vname] = vdata
|
|
|
|
self.known_cids.extend(sorted(known_cids))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_predict(
|
|
|
|
def sample_predict(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
@ -112,7 +112,7 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
sample: Sample,
|
|
|
|
sample: Sample,
|
|
|
|
) -> List[ConstraintName]:
|
|
|
|
) -> List[ConstraintName]:
|
|
|
|
pred: List[ConstraintName] = []
|
|
|
|
pred: List[ConstraintName] = []
|
|
|
|
if len(self.known_cids) == 0:
|
|
|
|
if len(self.known_violations) == 0:
|
|
|
|
logger.info("Classifiers not fitted. Skipping.")
|
|
|
|
logger.info("Classifiers not fitted. Skipping.")
|
|
|
|
return pred
|
|
|
|
return pred
|
|
|
|
x, _, cids = self.sample_xy_with_cids(instance, sample)
|
|
|
|
x, _, cids = self.sample_xy_with_cids(instance, sample)
|
|
|
@ -131,7 +131,9 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
@overrides
|
|
|
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
|
|
|
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
|
|
|
return sample.get_array(self.attr)
|
|
|
|
attr_encoded = sample.get_scalar(self.attr)
|
|
|
|
|
|
|
|
assert attr_encoded is not None
|
|
|
|
|
|
|
|
return self.decode(attr_encoded)
|
|
|
|
|
|
|
|
|
|
|
|
@overrides
|
|
|
|
@overrides
|
|
|
|
def fit_xy(
|
|
|
|
def fit_xy(
|
|
|
@ -153,11 +155,13 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
instance: Instance,
|
|
|
|
instance: Instance,
|
|
|
|
sample: Sample,
|
|
|
|
sample: Sample,
|
|
|
|
) -> Dict[str, float]:
|
|
|
|
) -> Dict[str, float]:
|
|
|
|
actual = sample.get_array(self.attr)
|
|
|
|
attr_encoded = sample.get_scalar(self.attr)
|
|
|
|
assert actual is not None
|
|
|
|
assert attr_encoded is not None
|
|
|
|
|
|
|
|
actual_violations = DynamicConstraintsComponent.decode(attr_encoded)
|
|
|
|
|
|
|
|
actual = set(actual_violations.keys())
|
|
|
|
pred = set(self.sample_predict(instance, sample))
|
|
|
|
pred = set(self.sample_predict(instance, sample))
|
|
|
|
tp, tn, fp, fn = 0, 0, 0, 0
|
|
|
|
tp, tn, fp, fn = 0, 0, 0, 0
|
|
|
|
for cid in self.known_cids:
|
|
|
|
for cid in self.known_violations.keys():
|
|
|
|
if cid in pred:
|
|
|
|
if cid in pred:
|
|
|
|
if cid in actual:
|
|
|
|
if cid in actual:
|
|
|
|
tp += 1
|
|
|
|
tp += 1
|
|
|
@ -169,3 +173,12 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tn += 1
|
|
|
|
tn += 1
|
|
|
|
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
|
|
|
|
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def encode(violations: Dict[ConstraintName, Any]) -> str:
|
|
|
|
|
|
|
|
return json.dumps({k.decode(): v for (k, v) in violations.items()})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def decode(violations_encoded: str) -> Dict[ConstraintName, Any]:
|
|
|
|
|
|
|
|
violations = json.loads(violations_encoded)
|
|
|
|
|
|
|
|
return {k.encode(): v for (k, v) in violations.items()}
|
|
|
|