Components: Switch from factory methods to prototype objects

This commit is contained in:
2021-04-01 08:34:56 -05:00
parent 59c734f2a1
commit bc8fe4dc98
9 changed files with 43 additions and 34 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Union, Callable, Optional, Any, TYPE_CHECKING, Tuple
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
import numpy as np
from sklearn.linear_model import LinearRegression
@@ -16,10 +16,11 @@ from sklearn.metrics import (
)
from miplearn.classifiers import Regressor
from miplearn.classifiers.sklearn import ScikitLearnRegressor
from miplearn.components.component import Component
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
from miplearn.types import MIPSolveStats, TrainingSample, LearningSolveStats, Features
from miplearn.types import TrainingSample, LearningSolveStats, Features
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
@@ -34,13 +35,15 @@ class ObjectiveValueComponent(Component):
def __init__(
self,
lb_regressor: Callable[[], Regressor] = LinearRegression,
ub_regressor: Callable[[], Regressor] = LinearRegression,
lb_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
ub_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
) -> None:
assert isinstance(lb_regressor, Regressor)
assert isinstance(ub_regressor, Regressor)
self.ub_regressor: Optional[Regressor] = None
self.lb_regressor: Optional[Regressor] = None
self.lb_regressor_factory = lb_regressor
self.ub_regressor_factory = ub_regressor
self.lb_regressor_prototype = lb_regressor
self.ub_regressor_prototype = ub_regressor
self._predicted_ub: Optional[float] = None
self._predicted_lb: Optional[float] = None
@@ -77,8 +80,8 @@ class ObjectiveValueComponent(Component):
stats["Objective: predicted LB"] = self._predicted_lb
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
self.lb_regressor = self.lb_regressor_factory()
self.ub_regressor = self.ub_regressor_factory()
self.lb_regressor = self.lb_regressor_prototype.clone()
self.ub_regressor = self.ub_regressor_prototype.clone()
logger.debug("Extracting features...")
x_train = self.x(training_instances)
y_train = self.y(training_instances)