LazyDynamic: Rewrite fit method

This commit is contained in:
2021-04-06 06:28:23 -05:00
parent 6e326d5d6e
commit bb91c83187
3 changed files with 184 additions and 4 deletions

View File

@@ -4,7 +4,7 @@
import logging
import sys
from typing import Any, Dict
from typing import Any, Dict, List, TYPE_CHECKING, Set, Hashable
import numpy as np
from tqdm.auto import tqdm
@@ -14,9 +14,13 @@ from miplearn.classifiers.counting import CountingClassifier
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.extractors import InstanceFeaturesExtractor
from miplearn.features import TrainingSample
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver, Instance
class DynamicLazyConstraintsComponent(Component):
"""
@@ -32,6 +36,7 @@ class DynamicLazyConstraintsComponent(Component):
self.threshold: float = threshold
self.classifier_prototype: Classifier = classifier
self.classifiers: Dict[Any, Classifier] = {}
self.known_cids: List[str] = []
def before_solve_mip(
self,
@@ -119,3 +124,50 @@ class DynamicLazyConstraintsComponent(Component):
fn = len(pred_negative & condition_positive)
results[idx] = classifier_evaluation_dict(tp, tn, fp, fn)
return results
def fit_new(self, training_instances: List["Instance"]) -> None:
# Update known_cids
self.known_cids.clear()
for instance in training_instances:
for sample in instance.training_data:
if sample.lazy_enforced is None:
continue
self.known_cids += list(sample.lazy_enforced)
self.known_cids = sorted(set(self.known_cids))
# Build x and y matrices
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
for instance in training_instances:
for sample in instance.training_data:
if sample.lazy_enforced is None:
continue
for cid in self.known_cids:
category = instance.get_constraint_category(cid)
if category is None:
continue
if category not in x:
x[category] = []
y[category] = []
assert instance.features.instance is not None
assert instance.features.instance.user_features is not None
cfeatures = instance.get_constraint_features(cid)
assert cfeatures is not None
assert isinstance(cfeatures, list)
for ci in cfeatures:
assert isinstance(ci, float)
f = list(instance.features.instance.user_features)
f += cfeatures
x[category] += [f]
if cid in sample.lazy_enforced:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
# Train classifiers
for category in x.keys():
self.classifiers[category] = self.classifier_prototype.clone()
self.classifiers[category].fit(
np.array(x[category]),
np.array(y[category]),
)

View File

@@ -119,7 +119,7 @@ class Instance(ABC):
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
return [0.0]
def get_constraint_category(self, cid: str) -> Optional[str]:
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
return cid
def has_static_lazy_constraints(self) -> bool:
@@ -243,7 +243,7 @@ class PickleGzInstance(Instance):
return self.instance.get_constraint_features(cid)
@lazy_load
def get_constraint_category(self, cid: str) -> Optional[str]:
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
assert self.instance is not None
return self.instance.get_constraint_category(cid)