mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
expert primal: Set value for int variables
This commit is contained in:
@@ -81,7 +81,6 @@ class BasicCollector:
|
|||||||
print(f"Error processing: data_filename")
|
print(f"Error processing: data_filename")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
if n_jobs > 1:
|
if n_jobs > 1:
|
||||||
p_umap(
|
p_umap(
|
||||||
_collect,
|
_collect,
|
||||||
|
|||||||
@@ -1,29 +1,53 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
from typing import Tuple
|
from typing import Tuple, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from miplearn.h5 import H5File
|
from miplearn.h5 import H5File
|
||||||
|
|
||||||
|
|
||||||
def _extract_bin_var_names_values(
|
def _extract_var_names_values(
|
||||||
h5: H5File,
|
h5: H5File,
|
||||||
|
selected_var_types: List[bytes],
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
bin_var_names, bin_var_indices = _extract_bin_var_names(h5)
|
bin_var_names, bin_var_indices = _extract_var_names(h5, selected_var_types)
|
||||||
var_values = h5.get_array("mip_var_values")
|
var_values = h5.get_array("mip_var_values")
|
||||||
assert var_values is not None
|
assert var_values is not None
|
||||||
bin_var_values = var_values[bin_var_indices].astype(int)
|
bin_var_values = var_values[bin_var_indices].astype(int)
|
||||||
return bin_var_names, bin_var_values, bin_var_indices
|
return bin_var_names, bin_var_values, bin_var_indices
|
||||||
|
|
||||||
|
|
||||||
def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]:
|
def _extract_var_names(
|
||||||
|
h5: H5File,
|
||||||
|
selected_var_types: List[bytes],
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
var_types = h5.get_array("static_var_types")
|
var_types = h5.get_array("static_var_types")
|
||||||
var_names = h5.get_array("static_var_names")
|
var_names = h5.get_array("static_var_names")
|
||||||
assert var_types is not None
|
assert var_types is not None
|
||||||
assert var_names is not None
|
assert var_names is not None
|
||||||
bin_var_indices = np.where(var_types == b"B")[0]
|
bin_var_indices = np.where(np.isin(var_types, selected_var_types))[0]
|
||||||
bin_var_names = var_names[bin_var_indices]
|
bin_var_names = var_names[bin_var_indices]
|
||||||
assert len(bin_var_names.shape) == 1
|
assert len(bin_var_names.shape) == 1
|
||||||
return bin_var_names, bin_var_indices
|
return bin_var_names, bin_var_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bin_var_names_values(
|
||||||
|
h5: H5File,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
return _extract_var_names_values(h5, [b"B"])
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
return _extract_var_names(h5, [b"B"])
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_int_var_names_values(
|
||||||
|
h5: H5File,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
return _extract_var_names_values(h5, [b"B", b"I"])
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_int_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
return _extract_var_names(h5, [b"B", b"I"])
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from . import _extract_bin_var_names_values
|
from . import _extract_int_var_names_values
|
||||||
from .actions import PrimalComponentAction
|
from .actions import PrimalComponentAction
|
||||||
from ...solvers.abstract import AbstractModel
|
from ...solvers.abstract import AbstractModel
|
||||||
from ...h5 import H5File
|
from ...h5 import H5File
|
||||||
@@ -28,5 +28,5 @@ class ExpertPrimalComponent:
|
|||||||
self, test_h5: str, model: AbstractModel, stats: Dict[str, Any]
|
self, test_h5: str, model: AbstractModel, stats: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
with H5File(test_h5, "r") as h5:
|
with H5File(test_h5, "r") as h5:
|
||||||
names, values, _ = _extract_bin_var_names_values(h5)
|
names, values, _ = _extract_int_var_names_values(h5)
|
||||||
self.action.perform(model, names, values.reshape(1, -1), stats)
|
self.action.perform(model, names, values.reshape(1, -1), stats)
|
||||||
|
|||||||
Reference in New Issue
Block a user