Components: Switch from factory methods to prototype objects

This commit is contained in:
2021-04-01 08:34:56 -05:00
parent 59c734f2a1
commit bc8fe4dc98
9 changed files with 43 additions and 34 deletions

View File

@@ -4,7 +4,6 @@
import logging
import sys
from copy import deepcopy
from typing import Any, Dict
import numpy as np
@@ -29,6 +28,7 @@ class UserCutsComponent(Component):
classifier: Classifier = CountingClassifier(),
threshold: float = 0.05,
):
assert isinstance(classifier, Classifier)
self.threshold: float = threshold
self.classifier_prototype: Classifier = classifier
self.classifiers: Dict[Any, Classifier] = {}
@@ -63,7 +63,7 @@ class UserCutsComponent(Component):
continue
for v in instance.found_violated_user_cuts:
if v not in self.classifiers:
self.classifiers[v] = deepcopy(self.classifier_prototype)
self.classifiers[v] = self.classifier_prototype.clone()
violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx]

View File

@@ -4,7 +4,6 @@
import logging
import sys
from copy import deepcopy
from typing import Any, Dict
import numpy as np
@@ -29,6 +28,7 @@ class DynamicLazyConstraintsComponent(Component):
classifier: Classifier = CountingClassifier(),
threshold: float = 0.05,
):
assert isinstance(classifier, Classifier)
self.threshold: float = threshold
self.classifier_prototype: Classifier = classifier
self.classifiers: Dict[Any, Classifier] = {}
@@ -75,7 +75,7 @@ class DynamicLazyConstraintsComponent(Component):
if isinstance(v, list):
v = tuple(v)
if v not in self.classifiers:
self.classifiers[v] = deepcopy(self.classifier_prototype)
self.classifiers[v] = self.classifier_prototype.clone()
violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx]

View File

@@ -4,12 +4,12 @@
import logging
import sys
from copy import deepcopy
from typing import Any, Dict, Tuple, Optional
from typing import Dict, Tuple, Optional
import numpy as np
from tqdm.auto import tqdm
from miplearn import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.components.component import Component
from miplearn.types import TrainingSample, Features
@@ -32,6 +32,7 @@ class StaticLazyConstraintsComponent(Component):
large_gap=1e-2,
violation_tolerance=-0.5,
):
assert isinstance(classifier, Classifier)
self.threshold = threshold
self.classifier_prototype = classifier
self.classifiers = {}
@@ -120,7 +121,7 @@ class StaticLazyConstraintsComponent(Component):
x.keys(), desc="Fit (lazy)", disable=not sys.stdout.isatty()
):
if category not in self.classifiers:
self.classifiers[category] = deepcopy(self.classifier_prototype)
self.classifiers[category] = self.classifier_prototype.clone()
self.classifiers[category].fit(x[category], y[category])
def predict(self, instance):

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Union, Callable, Optional, Any, TYPE_CHECKING, Tuple
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
import numpy as np
from sklearn.linear_model import LinearRegression
@@ -16,10 +16,11 @@ from sklearn.metrics import (
)
from miplearn.classifiers import Regressor
from miplearn.classifiers.sklearn import ScikitLearnRegressor
from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import MIPSolveStats, TrainingSample, LearningSolveStats, Features
from miplearn.types import TrainingSample, LearningSolveStats, Features
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
@@ -34,13 +35,15 @@ class ObjectiveValueComponent(Component):
def __init__(
self,
lb_regressor: Callable[[], Regressor] = LinearRegression,
ub_regressor: Callable[[], Regressor] = LinearRegression,
lb_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
ub_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
) -> None:
assert isinstance(lb_regressor, Regressor)
assert isinstance(ub_regressor, Regressor)
self.ub_regressor: Optional[Regressor] = None
self.lb_regressor: Optional[Regressor] = None
self.lb_regressor_factory = lb_regressor
self.ub_regressor_factory = ub_regressor
self.lb_regressor_prototype = lb_regressor
self.ub_regressor_prototype = ub_regressor
self._predicted_ub: Optional[float] = None
self._predicted_lb: Optional[float] = None
@@ -77,8 +80,8 @@ class ObjectiveValueComponent(Component):
stats["Objective: predicted LB"] = self._predicted_lb
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
self.lb_regressor = self.lb_regressor_factory()
self.ub_regressor = self.ub_regressor_factory()
self.lb_regressor = self.lb_regressor_prototype.clone()
self.ub_regressor = self.ub_regressor_prototype.clone()
logger.debug("Extracting features...")
x_train = self.x(training_instances)
y_train = self.y(training_instances)

View File

@@ -50,18 +50,18 @@ class PrimalSolutionComponent(Component):
def __init__(
self,
classifier: Callable[[], Classifier] = lambda: AdaptiveClassifier(),
classifier: Classifier = AdaptiveClassifier(),
mode: str = "exact",
threshold: Callable[[], Threshold] = lambda: MinPrecisionThreshold(
[0.98, 0.98]
),
threshold: Threshold = MinPrecisionThreshold([0.98, 0.98]),
) -> None:
assert isinstance(classifier, Classifier)
assert isinstance(threshold, Threshold)
assert mode in ["exact", "heuristic"]
self.mode = mode
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold
self.classifier_factory = classifier
self.threshold_prototype = threshold
self.classifier_prototype = classifier
self.stats: Dict[str, float] = {}
self._n_free = 0
self._n_zero = 0
@@ -114,8 +114,8 @@ class PrimalSolutionComponent(Component):
y: Dict[str, np.ndarray],
) -> None:
for category in x.keys():
clf = self.classifier_factory()
thr = self.threshold_factory()
clf = self.classifier_prototype.clone()
thr = self.threshold_prototype.clone()
clf.fit(x[category], y[category])
thr.fit(clf, x[category], y[category])
self.classifiers[category] = clf