Compare commits

...

3 Commits

Binary file not shown.

@ -81,7 +81,6 @@ class BasicCollector:
print(f"Error processing: data_filename") print(f"Error processing: data_filename")
traceback.print_exc() traceback.print_exc()
if n_jobs > 1: if n_jobs > 1:
p_umap( p_umap(
_collect, _collect,

@ -1,29 +1,53 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from typing import Tuple from typing import Tuple, List
import numpy as np import numpy as np
from miplearn.h5 import H5File from miplearn.h5 import H5File
def _extract_bin_var_names_values( def _extract_var_names_values(
h5: H5File, h5: H5File,
selected_var_types: List[bytes],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
bin_var_names, bin_var_indices = _extract_bin_var_names(h5) bin_var_names, bin_var_indices = _extract_var_names(h5, selected_var_types)
var_values = h5.get_array("mip_var_values") var_values = h5.get_array("mip_var_values")
assert var_values is not None assert var_values is not None
bin_var_values = var_values[bin_var_indices].astype(int) bin_var_values = var_values[bin_var_indices].astype(int)
return bin_var_names, bin_var_values, bin_var_indices return bin_var_names, bin_var_values, bin_var_indices
def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]: def _extract_var_names(
h5: H5File,
selected_var_types: List[bytes],
) -> Tuple[np.ndarray, np.ndarray]:
var_types = h5.get_array("static_var_types") var_types = h5.get_array("static_var_types")
var_names = h5.get_array("static_var_names") var_names = h5.get_array("static_var_names")
assert var_types is not None assert var_types is not None
assert var_names is not None assert var_names is not None
bin_var_indices = np.where(var_types == b"B")[0] bin_var_indices = np.where(np.isin(var_types, selected_var_types))[0]
bin_var_names = var_names[bin_var_indices] bin_var_names = var_names[bin_var_indices]
assert len(bin_var_names.shape) == 1 assert len(bin_var_names.shape) == 1
return bin_var_names, bin_var_indices return bin_var_names, bin_var_indices
def _extract_bin_var_names_values(
h5: H5File,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
return _extract_var_names_values(h5, [b"B"])
def _extract_bin_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]:
return _extract_var_names(h5, [b"B"])
def _extract_int_var_names_values(
h5: H5File,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
return _extract_var_names_values(h5, [b"B", b"I"])
def _extract_int_var_names(h5: H5File) -> Tuple[np.ndarray, np.ndarray]:
return _extract_var_names(h5, [b"B", b"I"])

@ -5,7 +5,7 @@
import logging import logging
from typing import Any, Dict, List from typing import Any, Dict, List
from . import _extract_bin_var_names_values from . import _extract_int_var_names_values
from .actions import PrimalComponentAction from .actions import PrimalComponentAction
from ...solvers.abstract import AbstractModel from ...solvers.abstract import AbstractModel
from ...h5 import H5File from ...h5 import H5File
@ -28,5 +28,5 @@ class ExpertPrimalComponent:
self, test_h5: str, model: AbstractModel, stats: Dict[str, Any] self, test_h5: str, model: AbstractModel, stats: Dict[str, Any]
) -> None: ) -> None:
with H5File(test_h5, "r") as h5: with H5File(test_h5, "r") as h5:
names, values, _ = _extract_bin_var_names_values(h5) names, values, _ = _extract_int_var_names_values(h5)
self.action.perform(model, names, values.reshape(1, -1), stats) self.action.perform(model, names, values.reshape(1, -1), stats)

@ -68,7 +68,7 @@ class H5File:
return return
self._assert_is_array(value) self._assert_is_array(value)
if value.dtype.kind == "f": if value.dtype.kind == "f":
value = value.astype("float32") value = value.astype("float64")
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
return self.file.create_dataset(key, data=value, compression="gzip") return self.file.create_dataset(key, data=value, compression="gzip")

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from os.path import exists from os.path import exists
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import List, Any, Union, Dict, Callable, Optional from typing import List, Any, Union, Dict, Callable, Optional, Tuple
from miplearn.h5 import H5File from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename from miplearn.io import _to_h5_filename
@ -25,7 +25,7 @@ class LearningSolver:
self, self,
model: Union[str, AbstractModel], model: Union[str, AbstractModel],
build_model: Optional[Callable] = None, build_model: Optional[Callable] = None,
) -> Dict[str, Any]: ) -> Tuple[AbstractModel, Dict[str, Any]]:
h5_filename, mode = NamedTemporaryFile().name, "w" h5_filename, mode = NamedTemporaryFile().name, "w"
if isinstance(model, str): if isinstance(model, str):
assert build_model is not None assert build_model is not None
@ -51,4 +51,4 @@ class LearningSolver:
model.optimize() model.optimize()
model.extract_after_mip(h5) model.extract_after_mip(h5)
return stats return model, stats

@ -71,5 +71,5 @@ def test_usage_stab(
comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor) comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp]) solver = LearningSolver(components=[comp])
solver.fit(data_filenames) solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Cuts: AOT"] > 0 assert stats["Cuts: AOT"] > 0

@ -65,5 +65,5 @@ def test_usage_tsp(
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor) comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp]) solver = LearningSolver(components=[comp])
solver.fit(data_filenames) solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Lazy Constraints: AOT"] > 0 assert stats["Lazy Constraints: AOT"] > 0

Loading…
Cancel
Save