Lazy: Simplify method signature; switch to AbstractModel

This commit is contained in:
2023-10-27 09:14:51 -05:00
parent 7079a36203
commit a42cd5ae35
5 changed files with 26 additions and 18 deletions

View File

@@ -9,16 +9,10 @@ from sklearn.preprocessing import MultiLabelBinarizer
from miplearn.extractors.abstract import FeaturesExtractor
from miplearn.h5 import H5File
from miplearn.solvers.gurobi import GurobiModel
from miplearn.solvers.abstract import AbstractModel
logger = logging.getLogger(__name__)
# TODO: Replace GurobiModel by AbstractModel
# TODO: fix_violations: remove model.inner
# TODO: fix_violations: remove `where`
# TODO: Write documentation
# TODO: Implement ExpertLazyConstrComponent
class MemorizingLazyConstrComponent:
def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None:
@@ -79,13 +73,14 @@ class MemorizingLazyConstrComponent:
def before_mip(
self,
test_h5: str,
model: GurobiModel,
model: AbstractModel,
stats: Dict[str, Any],
) -> None:
assert self.constrs_ is not None
if model.lazy_enforce is None:
return
assert self.constrs_ is not None
# Read features
with H5File(test_h5, "r") as h5:
x_sample = self.extractor.get_instance_features(h5)
@@ -101,5 +96,5 @@ class MemorizingLazyConstrComponent:
# Enforce constraints
violations = [self.constrs_[i] for (i, yi) in enumerate(y) if yi > 0.5]
logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...")
model.lazy_enforce(model, violations, "aot")
model.lazy_enforce(model, violations)
stats["Lazy Constraints: AOT"] = len(violations)