Rewrite DynamicLazy.sample_xy

This commit is contained in:
2021-04-12 07:41:22 -05:00
parent bccf0e9860
commit 6f6cd3018b
12 changed files with 171 additions and 40 deletions

View File

@@ -2,7 +2,8 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Dict, Hashable, List, Tuple, TYPE_CHECKING
import logging
from typing import Dict, Hashable, List, Tuple, Optional
import numpy as np
from overrides import overrides
@@ -11,15 +12,11 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.features import TrainingSample
import logging
from miplearn.features import TrainingSample, Sample
from miplearn.instance.base import Instance
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import Instance
class DynamicConstraintsComponent(Component):
"""
@@ -40,9 +37,9 @@ class DynamicConstraintsComponent(Component):
self.known_cids: List[str] = []
self.attr = attr
def sample_xy_with_cids(
def sample_xy_with_cids_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Tuple[
Dict[Hashable, List[List[float]]],
@@ -78,25 +75,78 @@ class DynamicConstraintsComponent(Component):
y[category] += [[True, False]]
return x, y, cids
def sample_xy_with_cids(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[
Dict[Hashable, List[List[float]]],
Dict[Hashable, List[List[bool]]],
Dict[Hashable, List[str]],
]:
assert instance is not None
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {}
for cid in self.known_cids:
# Initialize categories
category = instance.get_constraint_category(cid)
if category is None:
continue
if category not in x:
x[category] = []
y[category] = []
cids[category] = []
# Features
features = []
assert sample.after_lp is not None
assert sample.after_lp.instance is not None
features.extend(sample.after_lp.instance.to_list())
features.extend(instance.get_constraint_features(cid))
for ci in features:
assert isinstance(ci, float)
x[category].append(features)
cids[category].append(cid)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.extra is not None
if sample.after_mip.extra[self.attr] is not None:
if cid in sample.after_mip.extra[self.attr]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
return x, y, cids
@overrides
def sample_xy_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
x, y, _ = self.sample_xy_with_cids_old(instance, sample)
return x, y
@overrides
def sample_xy(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[Dict, Dict]:
x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y
def sample_predict(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> List[Hashable]:
pred: List[Hashable] = []
if len(self.known_cids) == 0:
logger.info("Classifiers not fitted. Skipping.")
return pred
x, _, cids = self.sample_xy_with_cids(instance, sample)
x, _, cids = self.sample_xy_with_cids_old(instance, sample)
for category in x.keys():
assert category in self.classifiers
assert category in self.thresholds
@@ -111,7 +161,7 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def fit(self, training_instances: List["Instance"]) -> None:
def fit(self, training_instances: List[Instance]) -> None:
collected_cids = set()
for instance in training_instances:
instance.load()
@@ -141,7 +191,7 @@ class DynamicConstraintsComponent(Component):
@overrides
def sample_evaluate_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
assert getattr(sample, self.attr) is not None