mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Components: Switch from factory methods to prototype objects
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Union, Callable, Optional, Any, TYPE_CHECKING, Tuple
|
||||
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
|
||||
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LinearRegression
|
||||
@@ -16,10 +16,11 @@ from sklearn.metrics import (
|
||||
)
|
||||
|
||||
from miplearn.classifiers import Regressor
|
||||
from miplearn.classifiers.sklearn import ScikitLearnRegressor
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.types import MIPSolveStats, TrainingSample, LearningSolveStats, Features
|
||||
from miplearn.types import TrainingSample, LearningSolveStats, Features
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
@@ -34,13 +35,15 @@ class ObjectiveValueComponent(Component):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lb_regressor: Callable[[], Regressor] = LinearRegression,
|
||||
ub_regressor: Callable[[], Regressor] = LinearRegression,
|
||||
lb_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
|
||||
ub_regressor: Regressor = ScikitLearnRegressor(LinearRegression()),
|
||||
) -> None:
|
||||
assert isinstance(lb_regressor, Regressor)
|
||||
assert isinstance(ub_regressor, Regressor)
|
||||
self.ub_regressor: Optional[Regressor] = None
|
||||
self.lb_regressor: Optional[Regressor] = None
|
||||
self.lb_regressor_factory = lb_regressor
|
||||
self.ub_regressor_factory = ub_regressor
|
||||
self.lb_regressor_prototype = lb_regressor
|
||||
self.ub_regressor_prototype = ub_regressor
|
||||
self._predicted_ub: Optional[float] = None
|
||||
self._predicted_lb: Optional[float] = None
|
||||
|
||||
@@ -77,8 +80,8 @@ class ObjectiveValueComponent(Component):
|
||||
stats["Objective: predicted LB"] = self._predicted_lb
|
||||
|
||||
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
|
||||
self.lb_regressor = self.lb_regressor_factory()
|
||||
self.ub_regressor = self.ub_regressor_factory()
|
||||
self.lb_regressor = self.lb_regressor_prototype.clone()
|
||||
self.ub_regressor = self.ub_regressor_prototype.clone()
|
||||
logger.debug("Extracting features...")
|
||||
x_train = self.x(training_instances)
|
||||
y_train = self.y(training_instances)
|
||||
|
||||
Reference in New Issue
Block a user