Implement sample.{get,put}_sparse

master
Alinson S. Xavier 4 years ago
parent 5b54153a3a
commit 63eff336e2
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -5,6 +5,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
from scipy.sparse import coo_matrix
import h5py import h5py
import numpy as np import numpy as np
@ -80,6 +81,14 @@ class Sample(ABC):
def get_array(self, key: str) -> Optional[np.ndarray]: def get_array(self, key: str) -> Optional[np.ndarray]:
pass pass
@abstractmethod
def put_sparse(self, key: str, value: coo_matrix) -> None:
pass
@abstractmethod
def get_sparse(self, key: str) -> Optional[coo_matrix]:
pass
def get_set(self, key: str) -> Set: def get_set(self, key: str) -> Set:
v = self.get_vector(key) v = self.get_vector(key)
if v: if v:
@ -118,6 +127,10 @@ class Sample(ABC):
assert isinstance(value, np.ndarray) assert isinstance(value, np.ndarray)
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}" assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
def _assert_is_sparse(self, value: Any) -> None:
assert isinstance(value, coo_matrix)
self._assert_supported(value.data)
class MemorySample(Sample): class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory.""" """Dictionary-like class that stores training data in-memory."""
@ -197,6 +210,17 @@ class MemorySample(Sample):
def get_array(self, key: str) -> Optional[np.ndarray]: def get_array(self, key: str) -> Optional[np.ndarray]:
return cast(Optional[np.ndarray], self._get(key)) return cast(Optional[np.ndarray], self._get(key))
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
if value is None:
return
self._assert_is_sparse(value)
self._put(key, value)
@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
return cast(Optional[coo_matrix], self._get(key))
class Hdf5Sample(Sample): class Hdf5Sample(Sample):
""" """
@ -351,6 +375,26 @@ class Hdf5Sample(Sample):
return None return None
return self.file[key][:] return self.file[key][:]
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
if value is None:
return
self._assert_is_sparse(value)
self.put_array(f"{key}_row", value.row)
self.put_array(f"{key}_col", value.col)
self.put_array(f"{key}_data", value.data)
@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
row = self.get_array(f"{key}_row")
if row is None:
return None
col = self.get_array(f"{key}_col")
data = self.get_array(f"{key}_data")
assert col is not None
assert data is not None
return coo_matrix((data, (row, col)))
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist) veclist = deepcopy(veclist)

@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
import numpy as np import numpy as np
from scipy.sparse import coo_matrix
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
@ -23,6 +24,8 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_scalar(sample, True) _assert_roundtrip_scalar(sample, True)
_assert_roundtrip_scalar(sample, 1) _assert_roundtrip_scalar(sample, 1)
_assert_roundtrip_scalar(sample, 1.0) _assert_roundtrip_scalar(sample, 1.0)
assert sample.get_scalar("unknown-key") is None
_assert_roundtrip_array(sample, np.array([True, False], dtype="bool")) _assert_roundtrip_array(sample, np.array([True, False], dtype="bool"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32"))
@ -31,24 +34,45 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32")) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32"))
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64")) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64"))
_assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S")) _assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S"))
assert sample.get_scalar("unknown-key") is None
assert sample.get_array("unknown-key") is None assert sample.get_array("unknown-key") is None
_assert_roundtrip_sparse(
sample,
coo_matrix(
[
[1, 0, 0],
[0, 2, 3],
[0, 0, 4],
],
dtype=float,
),
)
assert sample.get_sparse("unknown-key") is None
def _assert_roundtrip_array(sample: Sample, expected: Any) -> None:
sample.put_array("key", expected)
actual = sample.get_array("key")
assert actual is not None
assert isinstance(actual, np.ndarray)
assert actual.dtype == expected.dtype
assert (actual == expected).all()
def _assert_roundtrip_array(sample: Sample, original: np.ndarray) -> None:
sample.put_array("key", original)
recovered = sample.get_array("key")
assert recovered is not None
assert isinstance(recovered, np.ndarray)
assert recovered.dtype == original.dtype
assert (recovered == original).all()
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
sample.put_scalar("key", expected) def _assert_roundtrip_scalar(sample: Sample, original: Any) -> None:
actual = sample.get_scalar("key") sample.put_scalar("key", original)
assert actual == expected recovered = sample.get_scalar("key")
assert actual is not None assert recovered == original
assert recovered is not None
assert isinstance( assert isinstance(
actual, expected.__class__ recovered, original.__class__
), f"Expected {expected.__class__}, found {actual.__class__} instead" ), f"Expected {original.__class__}, found {recovered.__class__} instead"
def _assert_roundtrip_sparse(sample: Sample, original: coo_matrix) -> None:
sample.put_sparse("key", original)
recovered = sample.get_sparse("key")
assert recovered is not None
assert isinstance(recovered, coo_matrix)
assert recovered.dtype == original.dtype
assert (original != recovered).sum() == 0

Loading…
Cancel
Save