AlvLouWeh2017: Remove sample argument

master
Alinson S. Xavier 4 years ago
parent a65ebfb17c
commit 256d3d094f
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from math import log, isfinite
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Tuple, Optional
import numpy as np
@ -74,7 +74,9 @@ class FeaturesExtractor:
np.hstack(
[
vars_features_user,
self._extract_var_features_AlvLouWeh2017(sample),
self._extract_var_features_AlvLouWeh2017(
obj_coeffs=variables.obj_coeffs,
),
variables.lower_bounds.reshape(-1, 1),
variables.obj_coeffs.reshape(-1, 1),
variables.upper_bounds.reshape(-1, 1),
@ -111,7 +113,12 @@ class FeaturesExtractor:
lp_var_features_list = []
for f in [
sample.get_array("static_var_features"),
self._extract_var_features_AlvLouWeh2017(sample),
self._extract_var_features_AlvLouWeh2017(
obj_coeffs=sample.get_array("static_var_obj_coeffs"),
obj_sa_up=variables.sa_obj_up,
obj_sa_down=variables.sa_obj_down,
values=variables.values,
),
]:
if f is not None:
lp_var_features_list.append(f)
@ -310,12 +317,14 @@ class FeaturesExtractor:
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
# approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195.
def _extract_var_features_AlvLouWeh2017(self, sample: Sample) -> np.ndarray:
obj_coeffs = sample.get_array("static_var_obj_coeffs")
obj_sa_down = sample.get_array("lp_var_sa_obj_down")
obj_sa_up = sample.get_array("lp_var_sa_obj_up")
values = sample.get_array("lp_var_values")
# noinspection PyPep8Naming
def _extract_var_features_AlvLouWeh2017(
self,
obj_coeffs: Optional[np.ndarray] = None,
obj_sa_down: Optional[np.ndarray] = None,
obj_sa_up: Optional[np.ndarray] = None,
values: Optional[np.ndarray] = None,
) -> np.ndarray:
assert obj_coeffs is not None
obj_coeffs = obj_coeffs.astype(float)
_fix_infinity(obj_coeffs)

Loading…
Cancel
Save