Rewrite DynamicLazy.sample_xy

This commit is contained in:
2021-04-12 07:41:22 -05:00
parent bccf0e9860
commit 6f6cd3018b
12 changed files with 171 additions and 40 deletions

View File

@@ -2,7 +2,7 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable
from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Hashable, Optional
import numpy as np
from overrides import EnforceOverrides
@@ -119,7 +119,11 @@ class Component:
"""
pass
def sample_xy(self, sample: Sample) -> Tuple[Dict, Dict]:
def sample_xy(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[Dict, Dict]:
"""
Returns a pair of x and y dictionaries containing, respectively, the matrices
of ML features and the labels for the sample. If the training sample does not

View File

@@ -2,7 +2,8 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Dict, Hashable, List, Tuple, TYPE_CHECKING
import logging
from typing import Dict, Hashable, List, Tuple, Optional
import numpy as np
from overrides import overrides
@@ -11,15 +12,11 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.features import TrainingSample
import logging
from miplearn.features import TrainingSample, Sample
from miplearn.instance.base import Instance
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import Instance
class DynamicConstraintsComponent(Component):
"""
@@ -40,9 +37,9 @@ class DynamicConstraintsComponent(Component):
self.known_cids: List[str] = []
self.attr = attr
def sample_xy_with_cids(
def sample_xy_with_cids_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Tuple[
Dict[Hashable, List[List[float]]],
@@ -78,25 +75,78 @@ class DynamicConstraintsComponent(Component):
y[category] += [[True, False]]
return x, y, cids
def sample_xy_with_cids(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[
Dict[Hashable, List[List[float]]],
Dict[Hashable, List[List[bool]]],
Dict[Hashable, List[str]],
]:
assert instance is not None
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {}
for cid in self.known_cids:
# Initialize categories
category = instance.get_constraint_category(cid)
if category is None:
continue
if category not in x:
x[category] = []
y[category] = []
cids[category] = []
# Features
features = []
assert sample.after_lp is not None
assert sample.after_lp.instance is not None
features.extend(sample.after_lp.instance.to_list())
features.extend(instance.get_constraint_features(cid))
for ci in features:
assert isinstance(ci, float)
x[category].append(features)
cids[category].append(cid)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.extra is not None
if sample.after_mip.extra[self.attr] is not None:
if cid in sample.after_mip.extra[self.attr]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
return x, y, cids
@overrides
def sample_xy_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
x, y, _ = self.sample_xy_with_cids_old(instance, sample)
return x, y
@overrides
def sample_xy(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[Dict, Dict]:
x, y, _ = self.sample_xy_with_cids(instance, sample)
return x, y
def sample_predict(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> List[Hashable]:
pred: List[Hashable] = []
if len(self.known_cids) == 0:
logger.info("Classifiers not fitted. Skipping.")
return pred
x, _, cids = self.sample_xy_with_cids(instance, sample)
x, _, cids = self.sample_xy_with_cids_old(instance, sample)
for category in x.keys():
assert category in self.classifiers
assert category in self.thresholds
@@ -111,7 +161,7 @@ class DynamicConstraintsComponent(Component):
return pred
@overrides
def fit(self, training_instances: List["Instance"]) -> None:
def fit(self, training_instances: List[Instance]) -> None:
collected_cids = set()
for instance in training_instances:
instance.load()
@@ -141,7 +191,7 @@ class DynamicConstraintsComponent(Component):
@overrides
def sample_evaluate_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
assert getattr(sample, self.attr) is not None

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, List, TYPE_CHECKING, Hashable, Tuple, Any
from typing import Dict, List, TYPE_CHECKING, Hashable, Tuple, Any, Optional
import numpy as np
from overrides import overrides
@@ -14,7 +14,7 @@ from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component
from miplearn.components.dynamic_common import DynamicConstraintsComponent
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Sample
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
@@ -95,20 +95,28 @@ class DynamicLazyConstraintsComponent(Component):
@overrides
def sample_xy_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy_old(instance, sample)
@overrides
def sample_xy(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
def sample_predict(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> List[Hashable]:
return self.dynamic.sample_predict(instance, sample)
@overrides
def fit(self, training_instances: List["Instance"]) -> None:
def fit(self, training_instances: List[Instance]) -> None:
self.dynamic.fit(training_instances)
@overrides
@@ -122,7 +130,7 @@ class DynamicLazyConstraintsComponent(Component):
@overrides
def sample_evaluate_old(
self,
instance: "Instance",
instance: Instance,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
return self.dynamic.sample_evaluate_old(instance, sample)

View File

@@ -3,23 +3,24 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List, Optional
import numpy as np
from overrides import overrides
from miplearn.instance.base import Instance
from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
from miplearn.components.component import Component
from miplearn.components.dynamic_common import DynamicConstraintsComponent
from miplearn.features import Features, TrainingSample
from miplearn.features import Features, TrainingSample, Sample
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver, Instance
from miplearn.solvers.learning import LearningSolver
class UserCutsComponent(Component):
@@ -103,6 +104,14 @@ class UserCutsComponent(Component):
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy_old(instance, sample)
@overrides
def sample_xy(
self,
instance: Optional[Instance],
sample: Sample,
) -> Tuple[Dict, Dict]:
return self.dynamic.sample_xy(instance, sample)
def sample_predict(
self,
instance: "Instance",

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable
from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable, Optional
import numpy as np
from overrides import overrides
@@ -101,6 +101,7 @@ class ObjectiveValueComponent(Component):
@overrides
def sample_xy(
self,
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
# Instance features

View File

@@ -10,6 +10,7 @@ from typing import (
Any,
TYPE_CHECKING,
Tuple,
Optional,
)
import numpy as np
@@ -182,6 +183,7 @@ class PrimalSolutionComponent(Component):
@overrides
def sample_xy(
self,
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}

View File

@@ -3,11 +3,12 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set, Optional
import numpy as np
from overrides import overrides
from miplearn.instance.base import Instance
from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
@@ -18,7 +19,7 @@ from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver, Instance
from miplearn.solvers.learning import LearningSolver
class LazyConstraint:
@@ -202,6 +203,7 @@ class StaticLazyConstraintsComponent(Component):
@overrides
def sample_xy(
self,
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x: Dict = {}

View File

@@ -98,7 +98,7 @@ class Instance(ABC, EnforceOverrides):
"""
return "default"
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
def get_constraint_features(self, cid: str) -> List[float]:
return [0.0]
def get_constraint_category(self, cid: str) -> Optional[Hashable]: