AlvLouWeh2017: Remove sample argument

master
Alinson S. Xavier 4 years ago
parent a65ebfb17c
commit 256d3d094f
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from math import log, isfinite from math import log, isfinite
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple, Optional
import numpy as np import numpy as np
@ -74,7 +74,9 @@ class FeaturesExtractor:
np.hstack( np.hstack(
[ [
vars_features_user, vars_features_user,
self._extract_var_features_AlvLouWeh2017(sample), self._extract_var_features_AlvLouWeh2017(
obj_coeffs=variables.obj_coeffs,
),
variables.lower_bounds.reshape(-1, 1), variables.lower_bounds.reshape(-1, 1),
variables.obj_coeffs.reshape(-1, 1), variables.obj_coeffs.reshape(-1, 1),
variables.upper_bounds.reshape(-1, 1), variables.upper_bounds.reshape(-1, 1),
@ -111,7 +113,12 @@ class FeaturesExtractor:
lp_var_features_list = [] lp_var_features_list = []
for f in [ for f in [
sample.get_array("static_var_features"), sample.get_array("static_var_features"),
self._extract_var_features_AlvLouWeh2017(sample), self._extract_var_features_AlvLouWeh2017(
obj_coeffs=sample.get_array("static_var_obj_coeffs"),
obj_sa_up=variables.sa_obj_up,
obj_sa_down=variables.sa_obj_down,
values=variables.values,
),
]: ]:
if f is not None: if f is not None:
lp_var_features_list.append(f) lp_var_features_list.append(f)
@ -310,12 +317,14 @@ class FeaturesExtractor:
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based # Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
# approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195. # approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195.
def _extract_var_features_AlvLouWeh2017(self, sample: Sample) -> np.ndarray: # noinspection PyPep8Naming
obj_coeffs = sample.get_array("static_var_obj_coeffs") def _extract_var_features_AlvLouWeh2017(
obj_sa_down = sample.get_array("lp_var_sa_obj_down") self,
obj_sa_up = sample.get_array("lp_var_sa_obj_up") obj_coeffs: Optional[np.ndarray] = None,
values = sample.get_array("lp_var_values") obj_sa_down: Optional[np.ndarray] = None,
obj_sa_up: Optional[np.ndarray] = None,
values: Optional[np.ndarray] = None,
) -> np.ndarray:
assert obj_coeffs is not None assert obj_coeffs is not None
obj_coeffs = obj_coeffs.astype(float) obj_coeffs = obj_coeffs.astype(float)
_fix_infinity(obj_coeffs) _fix_infinity(obj_coeffs)

Loading…
Cancel
Save