mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 10:28:52 -06:00
Components: Switch from factory methods to prototype objects
This commit is contained in:
@@ -4,12 +4,12 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Tuple, Optional
|
||||
from typing import Dict, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.types import TrainingSample, Features
|
||||
@@ -32,6 +32,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
large_gap=1e-2,
|
||||
violation_tolerance=-0.5,
|
||||
):
|
||||
assert isinstance(classifier, Classifier)
|
||||
self.threshold = threshold
|
||||
self.classifier_prototype = classifier
|
||||
self.classifiers = {}
|
||||
@@ -120,7 +121,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
x.keys(), desc="Fit (lazy)", disable=not sys.stdout.isatty()
|
||||
):
|
||||
if category not in self.classifiers:
|
||||
self.classifiers[category] = deepcopy(self.classifier_prototype)
|
||||
self.classifiers[category] = self.classifier_prototype.clone()
|
||||
self.classifiers[category].fit(x[category], y[category])
|
||||
|
||||
def predict(self, instance):
|
||||
|
||||
Reference in New Issue
Block a user