diff --git a/miplearn/collectors/basic.py b/miplearn/collectors/basic.py index 9ed6b03..a5dac1b 100644 --- a/miplearn/collectors/basic.py +++ b/miplearn/collectors/basic.py @@ -81,7 +81,6 @@ class BasicCollector: print(f"Error processing: data_filename") traceback.print_exc() - if n_jobs > 1: p_umap( _collect, diff --git a/miplearn/components/primal/__init__.py b/miplearn/components/primal/__init__.py index 4c8a414..cac8d87 100644 --- a/miplearn/components/primal/__init__.py +++ b/miplearn/components/primal/__init__.py @@ -1,29 +1,53 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from typing import Tuple +from typing import Tuple, List import numpy as np from miplearn.h5 import H5File -def _extract_bin_var_names_values( +def _extract_var_names_values( h5: H5File, + selected_var_types: List[bytes], ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - bin_var_names, bin_var_indices = _extract_bin_var_names(h5) + bin_var_names, bin_var_indices = _extract_var_names(h5, selected_var_types) var_values = h5.get_array("mip_var_values") assert var_values is not None bin_var_values = var_values[bin_var_indices].astype(int) return bin_var_names, bin_var_values, bin_var_indices -def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]: +def _extract_var_names( + h5: H5File, + selected_var_types: List[bytes], +) -> Tuple[np.ndarray, np.ndarray]: var_types = h5.get_array("static_var_types") var_names = h5.get_array("static_var_names") assert var_types is not None assert var_names is not None - bin_var_indices = np.where(var_types == b"B")[0] + bin_var_indices = np.where(np.isin(var_types, selected_var_types))[0] bin_var_names = var_names[bin_var_indices] assert len(bin_var_names.shape) == 1 return bin_var_names, bin_var_indices + + +def _extract_bin_var_names_values( + h5: H5File, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return _extract_var_names_values(h5, [b"B"]) + + +def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]: + return _extract_var_names(h5, [b"B"]) + + +def _extract_int_var_names_values( + h5: H5File, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return _extract_var_names_values(h5, [b"B", b"I"]) + + +def _extract_int_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]: + return _extract_var_names(h5, [b"B", b"I"]) diff --git a/miplearn/components/primal/expert.py b/miplearn/components/primal/expert.py index 8bf4319..871a9a5 100644 --- a/miplearn/components/primal/expert.py +++ b/miplearn/components/primal/expert.py @@ -5,7 +5,7 @@ import logging from typing import Any, Dict, List -from . import _extract_bin_var_names_values +from . import _extract_int_var_names_values from .actions import PrimalComponentAction from ...solvers.abstract import AbstractModel from ...h5 import H5File @@ -28,5 +28,5 @@ class ExpertPrimalComponent: self, test_h5: str, model: AbstractModel, stats: Dict[str, Any] ) -> None: with H5File(test_h5, "r") as h5: - names, values, _ = _extract_bin_var_names_values(h5) + names, values, _ = _extract_int_var_names_values(h5) self.action.perform(model, names, values.reshape(1, -1), stats)