Enforce more overrides

This commit is contained in:
2021-04-07 12:01:05 -05:00
parent 1cf6124757
commit 96093a9b8e
8 changed files with 42 additions and 1 deletions

View File

@@ -6,6 +6,7 @@ import logging
from typing import Any, TYPE_CHECKING, Hashable, Set, Tuple, Dict, List
import numpy as np
from overrides import overrides
from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
@@ -35,6 +36,7 @@ class UserCutsComponent(Component):
self.enforced: Set[Hashable] = set()
self.n_added_in_callback = 0
@overrides
def before_solve_mip(
self,
solver: "LearningSolver",
@@ -55,6 +57,7 @@ class UserCutsComponent(Component):
solver.internal_solver.add_constraint(cobj)
stats["UserCuts: Added ahead-of-time"] = len(cids)
@overrides
def user_cut_cb(
self,
solver: "LearningSolver",
@@ -78,6 +81,7 @@ class UserCutsComponent(Component):
if len(cids) > 0:
logger.debug(f"Added {len(cids)} violated user cuts")
@overrides
def after_solve_mip(
self,
solver: "LearningSolver",
@@ -93,6 +97,7 @@ class UserCutsComponent(Component):
# Delegate ML methods to self.dynamic
# -------------------------------------------------------------------
@overrides
def sample_xy(
self,
instance: "Instance",
@@ -107,9 +112,11 @@ class UserCutsComponent(Component):
) -> List[Hashable]:
return self.dynamic.sample_predict(instance, sample)
@overrides
def fit(self, training_instances: List["Instance"]) -> None:
self.dynamic.fit(training_instances)
@overrides
def fit_xy(
self,
x: Dict[Hashable, np.ndarray],
@@ -117,6 +124,7 @@ class UserCutsComponent(Component):
) -> None:
self.dynamic.fit_xy(x, y)
@overrides
def sample_evaluate(
self,
instance: "Instance",