mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-10 11:28:51 -06:00
Enforce more overrides
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import List, Dict, Any, TYPE_CHECKING, Tuple, Hashable
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
from sklearn.linear_model import LinearRegression
|
||||
|
||||
from miplearn.classifiers import Regressor
|
||||
@@ -34,6 +35,7 @@ class ObjectiveValueComponent(Component):
|
||||
self.regressors: Dict[str, Regressor] = {}
|
||||
self.regressor_prototype = regressor
|
||||
|
||||
@overrides
|
||||
def before_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
@@ -49,6 +51,7 @@ class ObjectiveValueComponent(Component):
|
||||
logger.info(f"Predicted {c.lower()}: %.6e" % v)
|
||||
stats[f"Objective: Predicted {c.lower()}"] = v # type: ignore
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[Hashable, np.ndarray],
|
||||
@@ -73,6 +76,7 @@ class ObjectiveValueComponent(Component):
|
||||
logger.info(f"{c} regressor not fitted. Skipping.")
|
||||
return pred
|
||||
|
||||
@overrides
|
||||
def sample_xy(
|
||||
self,
|
||||
instance: Instance,
|
||||
@@ -94,6 +98,7 @@ class ObjectiveValueComponent(Component):
|
||||
y["Upper bound"] = [[sample.upper_bound]]
|
||||
return x, y
|
||||
|
||||
@overrides
|
||||
def sample_evaluate(
|
||||
self,
|
||||
instance: Instance,
|
||||
|
||||
Reference in New Issue
Block a user