Sample: do not check data by default; minor fixes

master
Alinson S. Xavier 4 years ago
parent 95b9ce29fd
commit 475fe3d985

@ -110,10 +110,12 @@ class MemorySample(Sample):
def __init__( def __init__(
self, self,
data: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None,
check_data: bool = False,
) -> None: ) -> None:
if data is None: if data is None:
data = {} data = {}
self._data: Dict[str, Any] = data self._data: Dict[str, Any] = data
self._check_data = check_data
@overrides @overrides
def get_bytes(self, key: str) -> Optional[Bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
@ -142,6 +144,7 @@ class MemorySample(Sample):
def put_scalar(self, key: str, value: Scalar) -> None: def put_scalar(self, key: str, value: Scalar) -> None:
if value is None: if value is None:
return return
if self._check_data:
self._assert_is_scalar(value) self._assert_is_scalar(value)
self._put(key, value) self._put(key, value)
@ -149,11 +152,13 @@ class MemorySample(Sample):
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
if self._check_data:
self._assert_is_vector(value) self._assert_is_vector(value)
self._put(key, value) self._put(key, value)
@overrides @overrides
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
if self._check_data:
self._assert_is_vector_list(value) self._assert_is_vector_list(value)
self._put(key, value) self._put(key, value)
@ -175,8 +180,14 @@ class Hdf5Sample(Sample):
are actually accessed, and therefore it is more scalable. are actually accessed, and therefore it is more scalable.
""" """
def __init__(self, filename: str, mode: str = "r+") -> None: def __init__(
self,
filename: str,
mode: str = "r+",
check_data: bool = False,
) -> None:
self.file = h5py.File(filename, mode, libver="latest") self.file = h5py.File(filename, mode, libver="latest")
self._check_data = check_data
@overrides @overrides
def get_bytes(self, key: str) -> Optional[Bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
@ -230,6 +241,7 @@ class Hdf5Sample(Sample):
@overrides @overrides
def put_bytes(self, key: str, value: Bytes) -> None: def put_bytes(self, key: str, value: Bytes) -> None:
if self._check_data:
assert isinstance( assert isinstance(
value, (bytes, bytearray) value, (bytes, bytearray)
), f"bytes expected; found: {value}" # type: ignore ), f"bytes expected; found: {value}" # type: ignore
@ -239,6 +251,7 @@ class Hdf5Sample(Sample):
def put_scalar(self, key: str, value: Any) -> None: def put_scalar(self, key: str, value: Any) -> None:
if value is None: if value is None:
return return
if self._check_data:
self._assert_is_scalar(value) self._assert_is_scalar(value)
self._put(key, value) self._put(key, value)
@ -246,11 +259,12 @@ class Hdf5Sample(Sample):
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
if self._check_data:
self._assert_is_vector(value) self._assert_is_vector(value)
for v in value: for v in value:
# Convert strings to bytes # Convert strings to bytes
if isinstance(v, str): if isinstance(v, str) or v is None:
value = np.array( value = np.array(
[u if u is not None else b"" for u in value], [u if u is not None else b"" for u in value],
dtype="S", dtype="S",
@ -266,6 +280,7 @@ class Hdf5Sample(Sample):
@overrides @overrides
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
if self._check_data:
self._assert_is_vector_list(value) self._assert_is_vector_list(value)
padded, lens = _pad(value) padded, lens = _pad(value)
self.put_vector(f"{key}_lengths", lens) self.put_vector(f"{key}_lengths", lens)
@ -297,7 +312,6 @@ class Hdf5Sample(Sample):
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)
lens = [len(v) if v is not None else -1 for v in veclist] lens = [len(v) if v is not None else -1 for v in veclist]
maxlen = max(lens) maxlen = max(lens)

@ -2,13 +2,20 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import sys
import time
from typing import Any
import numpy as np import numpy as np
import gurobipy as gp
from miplearn.features.extractor import FeaturesExtractor from miplearn.features.extractor import FeaturesExtractor
from miplearn.features.sample import Sample, MemorySample from miplearn.features.sample import MemorySample, Hdf5Sample
from miplearn.solvers.internal import Variables, Constraints from miplearn.instance.base import Instance
from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import Variables, Constraints
from miplearn.solvers.tests import assert_equals from miplearn.solvers.tests import assert_equals
import cProfile
inf = float("inf") inf = float("inf")
@ -166,3 +173,27 @@ def test_assert_equals() -> None:
assert_equals(np.array([True, True]), [True, True]) assert_equals(np.array([True, True]), [True, True])
assert_equals((1.0,), (1.0,)) assert_equals((1.0,), (1.0,))
assert_equals({"x": 10}, {"x": 10}) assert_equals({"x": 10}, {"x": 10})
class MpsInstance(Instance):
def __init__(self, filename: str) -> None:
super().__init__()
self.filename = filename
def to_model(self) -> Any:
return gp.read(self.filename)
if __name__ == "__main__":
solver = GurobiSolver()
instance = MpsInstance(sys.argv[1])
solver.set_instance(instance)
solver.solve_lp(tee=True)
extractor = FeaturesExtractor(with_lhs=False)
sample = Hdf5Sample("tmp/prof.h5", mode="w")
def run():
extractor.extract_after_load_features(instance, solver, sample)
extractor.extract_after_lp_features(solver, sample)
cProfile.run("run()", filename="tmp/prof")

Loading…
Cancel
Save