mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Sample: do not check data by default; minor fixes
This commit is contained in:
@@ -110,10 +110,12 @@ class MemorySample(Sample):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[Dict[str, Any]] = None,
|
||||||
|
check_data: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {}
|
data = {}
|
||||||
self._data: Dict[str, Any] = data
|
self._data: Dict[str, Any] = data
|
||||||
|
self._check_data = check_data
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[Bytes]:
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
@@ -142,19 +144,22 @@ class MemorySample(Sample):
|
|||||||
def put_scalar(self, key: str, value: Scalar) -> None:
|
def put_scalar(self, key: str, value: Scalar) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
self._assert_is_scalar(value)
|
if self._check_data:
|
||||||
|
self._assert_is_scalar(value)
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_vector(self, key: str, value: Vector) -> None:
|
def put_vector(self, key: str, value: Vector) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
self._assert_is_vector(value)
|
if self._check_data:
|
||||||
|
self._assert_is_vector(value)
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_vector_list(self, key: str, value: VectorList) -> None:
|
def put_vector_list(self, key: str, value: VectorList) -> None:
|
||||||
self._assert_is_vector_list(value)
|
if self._check_data:
|
||||||
|
self._assert_is_vector_list(value)
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
def _get(self, key: str) -> Optional[Any]:
|
def _get(self, key: str) -> Optional[Any]:
|
||||||
@@ -175,8 +180,14 @@ class Hdf5Sample(Sample):
|
|||||||
are actually accessed, and therefore it is more scalable.
|
are actually accessed, and therefore it is more scalable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, filename: str, mode: str = "r+") -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
mode: str = "r+",
|
||||||
|
check_data: bool = False,
|
||||||
|
) -> None:
|
||||||
self.file = h5py.File(filename, mode, libver="latest")
|
self.file = h5py.File(filename, mode, libver="latest")
|
||||||
|
self._check_data = check_data
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[Bytes]:
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
@@ -230,27 +241,30 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_bytes(self, key: str, value: Bytes) -> None:
|
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||||
assert isinstance(
|
if self._check_data:
|
||||||
value, (bytes, bytearray)
|
assert isinstance(
|
||||||
), f"bytes expected; found: {value}" # type: ignore
|
value, (bytes, bytearray)
|
||||||
|
), f"bytes expected; found: {value}" # type: ignore
|
||||||
self._put(key, np.frombuffer(value, dtype="uint8"), compress=True)
|
self._put(key, np.frombuffer(value, dtype="uint8"), compress=True)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_scalar(self, key: str, value: Any) -> None:
|
def put_scalar(self, key: str, value: Any) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
self._assert_is_scalar(value)
|
if self._check_data:
|
||||||
|
self._assert_is_scalar(value)
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_vector(self, key: str, value: Vector) -> None:
|
def put_vector(self, key: str, value: Vector) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
self._assert_is_vector(value)
|
if self._check_data:
|
||||||
|
self._assert_is_vector(value)
|
||||||
|
|
||||||
for v in value:
|
for v in value:
|
||||||
# Convert strings to bytes
|
# Convert strings to bytes
|
||||||
if isinstance(v, str):
|
if isinstance(v, str) or v is None:
|
||||||
value = np.array(
|
value = np.array(
|
||||||
[u if u is not None else b"" for u in value],
|
[u if u is not None else b"" for u in value],
|
||||||
dtype="S",
|
dtype="S",
|
||||||
@@ -266,7 +280,8 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_vector_list(self, key: str, value: VectorList) -> None:
|
def put_vector_list(self, key: str, value: VectorList) -> None:
|
||||||
self._assert_is_vector_list(value)
|
if self._check_data:
|
||||||
|
self._assert_is_vector_list(value)
|
||||||
padded, lens = _pad(value)
|
padded, lens = _pad(value)
|
||||||
self.put_vector(f"{key}_lengths", lens)
|
self.put_vector(f"{key}_lengths", lens)
|
||||||
data = None
|
data = None
|
||||||
@@ -297,7 +312,6 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
|
|
||||||
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
|
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
|
||||||
veclist = deepcopy(veclist)
|
|
||||||
lens = [len(v) if v is not None else -1 for v in veclist]
|
lens = [len(v) if v is not None else -1 for v in veclist]
|
||||||
maxlen = max(lens)
|
maxlen = max(lens)
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,20 @@
|
|||||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import gurobipy as gp
|
||||||
|
|
||||||
from miplearn.features.extractor import FeaturesExtractor
|
from miplearn.features.extractor import FeaturesExtractor
|
||||||
from miplearn.features.sample import Sample, MemorySample
|
from miplearn.features.sample import MemorySample, Hdf5Sample
|
||||||
from miplearn.solvers.internal import Variables, Constraints
|
from miplearn.instance.base import Instance
|
||||||
from miplearn.solvers.gurobi import GurobiSolver
|
from miplearn.solvers.gurobi import GurobiSolver
|
||||||
|
from miplearn.solvers.internal import Variables, Constraints
|
||||||
from miplearn.solvers.tests import assert_equals
|
from miplearn.solvers.tests import assert_equals
|
||||||
|
import cProfile
|
||||||
|
|
||||||
inf = float("inf")
|
inf = float("inf")
|
||||||
|
|
||||||
@@ -166,3 +173,27 @@ def test_assert_equals() -> None:
|
|||||||
assert_equals(np.array([True, True]), [True, True])
|
assert_equals(np.array([True, True]), [True, True])
|
||||||
assert_equals((1.0,), (1.0,))
|
assert_equals((1.0,), (1.0,))
|
||||||
assert_equals({"x": 10}, {"x": 10})
|
assert_equals({"x": 10}, {"x": 10})
|
||||||
|
|
||||||
|
|
||||||
|
class MpsInstance(Instance):
|
||||||
|
def __init__(self, filename: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.filename = filename
|
||||||
|
|
||||||
|
def to_model(self) -> Any:
|
||||||
|
return gp.read(self.filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
solver = GurobiSolver()
|
||||||
|
instance = MpsInstance(sys.argv[1])
|
||||||
|
solver.set_instance(instance)
|
||||||
|
solver.solve_lp(tee=True)
|
||||||
|
extractor = FeaturesExtractor(with_lhs=False)
|
||||||
|
sample = Hdf5Sample("tmp/prof.h5", mode="w")
|
||||||
|
|
||||||
|
def run():
|
||||||
|
extractor.extract_after_load_features(instance, solver, sample)
|
||||||
|
extractor.extract_after_lp_features(solver, sample)
|
||||||
|
|
||||||
|
cProfile.run("run()", filename="tmp/prof")
|
||||||
|
|||||||
Reference in New Issue
Block a user