mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Add type annotations to components
This commit is contained in:
@@ -5,10 +5,12 @@
|
||||
import logging
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
@@ -24,15 +26,12 @@ class UserCutsComponent(Component):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier=CountingClassifier(),
|
||||
threshold=0.05,
|
||||
classifier: Classifier = CountingClassifier(),
|
||||
threshold: float = 0.05,
|
||||
):
|
||||
self.violations = set()
|
||||
self.count = {}
|
||||
self.n_samples = 0
|
||||
self.threshold = threshold
|
||||
self.classifier_prototype = classifier
|
||||
self.classifiers = {}
|
||||
self.threshold: float = threshold
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.classifiers: Dict[Any, Classifier] = {}
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
instance.found_violated_user_cuts = []
|
||||
|
||||
Reference in New Issue
Block a user