diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index fa0085b..4842bc5 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -5,6 +5,7 @@ import warnings from abc import ABC, abstractmethod from copy import deepcopy from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set +from scipy.sparse import coo_matrix import h5py import numpy as np @@ -80,6 +81,14 @@ class Sample(ABC): def get_array(self, key: str) -> Optional[np.ndarray]: pass + @abstractmethod + def put_sparse(self, key: str, value: coo_matrix) -> None: + pass + + @abstractmethod + def get_sparse(self, key: str) -> Optional[coo_matrix]: + pass + def get_set(self, key: str) -> Set: v = self.get_vector(key) if v: @@ -118,6 +127,10 @@ class Sample(ABC): assert isinstance(value, np.ndarray) assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}" + def _assert_is_sparse(self, value: Any) -> None: + assert isinstance(value, coo_matrix) + self._assert_supported(value.data) + class MemorySample(Sample): """Dictionary-like class that stores training data in-memory.""" @@ -197,6 +210,17 @@ class MemorySample(Sample): def get_array(self, key: str) -> Optional[np.ndarray]: return cast(Optional[np.ndarray], self._get(key)) + @overrides + def put_sparse(self, key: str, value: coo_matrix) -> None: + if value is None: + return + self._assert_is_sparse(value) + self._put(key, value) + + @overrides + def get_sparse(self, key: str) -> Optional[coo_matrix]: + return cast(Optional[coo_matrix], self._get(key)) + class Hdf5Sample(Sample): """ @@ -351,6 +375,26 @@ class Hdf5Sample(Sample): return None return self.file[key][:] + @overrides + def put_sparse(self, key: str, value: coo_matrix) -> None: + if value is None: + return + self._assert_is_sparse(value) + self.put_array(f"{key}_row", value.row) + self.put_array(f"{key}_col", value.col) + self.put_array(f"{key}_data", value.data) + + @overrides + def get_sparse(self, key: str) -> Optional[coo_matrix]: + row = self.get_array(f"{key}_row") + if row is None: + return None + col = self.get_array(f"{key}_col") + data = self.get_array(f"{key}_data") + assert col is not None + assert data is not None + return coo_matrix((data, (row, col))) + def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: veclist = deepcopy(veclist) diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 9727713..6802848 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile from typing import Any import numpy as np +from scipy.sparse import coo_matrix from miplearn.features.sample import MemorySample, Sample, Hdf5Sample @@ -23,6 +24,8 @@ def _test_sample(sample: Sample) -> None: _assert_roundtrip_scalar(sample, True) _assert_roundtrip_scalar(sample, 1) _assert_roundtrip_scalar(sample, 1.0) + assert sample.get_scalar("unknown-key") is None + _assert_roundtrip_array(sample, np.array([True, False], dtype="bool")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16")) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32")) @@ -31,24 +34,45 @@ def _test_sample(sample: Sample) -> None: _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32")) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64")) _assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S")) - assert sample.get_scalar("unknown-key") is None assert sample.get_array("unknown-key") is None + _assert_roundtrip_sparse( + sample, + coo_matrix( + [ + [1, 0, 0], + [0, 2, 3], + [0, 0, 4], + ], + dtype=float, + ), + ) + assert sample.get_sparse("unknown-key") is None -def _assert_roundtrip_array(sample: Sample, expected: Any) -> None: - sample.put_array("key", expected) - actual = sample.get_array("key") - assert actual is not None - assert isinstance(actual, np.ndarray) - assert actual.dtype == expected.dtype - assert (actual == expected).all() +def _assert_roundtrip_array(sample: Sample, original: np.ndarray) -> None: + sample.put_array("key", original) + recovered = sample.get_array("key") + assert recovered is not None + assert isinstance(recovered, np.ndarray) + assert recovered.dtype == original.dtype + assert (recovered == original).all() -def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: - sample.put_scalar("key", expected) - actual = sample.get_scalar("key") - assert actual == expected - assert actual is not None + +def _assert_roundtrip_scalar(sample: Sample, original: Any) -> None: + sample.put_scalar("key", original) + recovered = sample.get_scalar("key") + assert recovered == original + assert recovered is not None assert isinstance( - actual, expected.__class__ - ), f"Expected {expected.__class__}, found {actual.__class__} instead" + recovered, original.__class__ + ), f"Expected {original.__class__}, found {recovered.__class__} instead" + + +def _assert_roundtrip_sparse(sample: Sample, original: coo_matrix) -> None: + sample.put_sparse("key", original) + recovered = sample.get_sparse("key") + assert recovered is not None + assert isinstance(recovered, coo_matrix) + assert recovered.dtype == original.dtype + assert (original != recovered).sum() == 0