Implement sample.{get,put}_sparse

This commit is contained in:
2021-08-09 07:09:02 -05:00
parent 5b54153a3a
commit 63eff336e2
2 changed files with 84 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
from scipy.sparse import coo_matrix
import h5py import h5py
import numpy as np import numpy as np
@@ -80,6 +81,14 @@ class Sample(ABC):
def get_array(self, key: str) -> Optional[np.ndarray]: def get_array(self, key: str) -> Optional[np.ndarray]:
pass pass
@abstractmethod
def put_sparse(self, key: str, value: coo_matrix) -> None:
pass
@abstractmethod
def get_sparse(self, key: str) -> Optional[coo_matrix]:
pass
def get_set(self, key: str) -> Set: def get_set(self, key: str) -> Set:
v = self.get_vector(key) v = self.get_vector(key)
if v: if v:
@@ -118,6 +127,10 @@ class Sample(ABC):
assert isinstance(value, np.ndarray) assert isinstance(value, np.ndarray)
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}" assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
def _assert_is_sparse(self, value: Any) -> None:
assert isinstance(value, coo_matrix)
self._assert_supported(value.data)
class MemorySample(Sample): class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory.""" """Dictionary-like class that stores training data in-memory."""
@@ -197,6 +210,17 @@ class MemorySample(Sample):
def get_array(self, key: str) -> Optional[np.ndarray]: def get_array(self, key: str) -> Optional[np.ndarray]:
return cast(Optional[np.ndarray], self._get(key)) return cast(Optional[np.ndarray], self._get(key))
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
if value is None:
return
self._assert_is_sparse(value)
self._put(key, value)
@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
return cast(Optional[coo_matrix], self._get(key))
class Hdf5Sample(Sample): class Hdf5Sample(Sample):
""" """
@@ -351,6 +375,26 @@ class Hdf5Sample(Sample):
return None return None
return self.file[key][:] return self.file[key][:]
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
if value is None:
return
self._assert_is_sparse(value)
self.put_array(f"{key}_row", value.row)
self.put_array(f"{key}_col", value.col)
self.put_array(f"{key}_data", value.data)
@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
row = self.get_array(f"{key}_row")
if row is None:
return None
col = self.get_array(f"{key}_col")
data = self.get_array(f"{key}_data")
assert col is not None
assert data is not None
return coo_matrix((data, (row, col)))
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist) veclist = deepcopy(veclist)

View File

@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
import numpy as np import numpy as np
from scipy.sparse import coo_matrix
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
@@ -23,6 +24,8 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_scalar(sample, True) _assert_roundtrip_scalar(sample, True)
_assert_roundtrip_scalar(sample, 1) _assert_roundtrip_scalar(sample, 1)
_assert_roundtrip_scalar(sample, 1.0) _assert_roundtrip_scalar(sample, 1.0)
assert sample.get_scalar("unknown-key") is None
_assert_roundtrip_array(sample, np.array([True, False], dtype="bool")) _assert_roundtrip_array(sample, np.array([True, False], dtype="bool"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32"))
@@ -31,24 +34,45 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32")) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32"))
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64")) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64"))
_assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S")) _assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S"))
assert sample.get_scalar("unknown-key") is None
assert sample.get_array("unknown-key") is None assert sample.get_array("unknown-key") is None
_assert_roundtrip_sparse(
def _assert_roundtrip_array(sample: Sample, expected: Any) -> None: sample,
sample.put_array("key", expected) coo_matrix(
actual = sample.get_array("key") [
assert actual is not None [1, 0, 0],
assert isinstance(actual, np.ndarray) [0, 2, 3],
assert actual.dtype == expected.dtype [0, 0, 4],
assert (actual == expected).all() ],
dtype=float,
),
)
assert sample.get_sparse("unknown-key") is None
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: def _assert_roundtrip_array(sample: Sample, original: np.ndarray) -> None:
sample.put_scalar("key", expected) sample.put_array("key", original)
actual = sample.get_scalar("key") recovered = sample.get_array("key")
assert actual == expected assert recovered is not None
assert actual is not None assert isinstance(recovered, np.ndarray)
assert recovered.dtype == original.dtype
assert (recovered == original).all()
def _assert_roundtrip_scalar(sample: Sample, original: Any) -> None:
sample.put_scalar("key", original)
recovered = sample.get_scalar("key")
assert recovered == original
assert recovered is not None
assert isinstance( assert isinstance(
actual, expected.__class__ recovered, original.__class__
), f"Expected {expected.__class__}, found {actual.__class__} instead" ), f"Expected {original.__class__}, found {recovered.__class__} instead"
def _assert_roundtrip_sparse(sample: Sample, original: coo_matrix) -> None:
sample.put_sparse("key", original)
recovered = sample.get_sparse("key")
assert recovered is not None
assert isinstance(recovered, coo_matrix)
assert recovered.dtype == original.dtype
assert (original != recovered).sum() == 0