mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement sample.{get,put}_sparse
This commit is contained in:
@@ -5,6 +5,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
|
||||
from scipy.sparse import coo_matrix
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
@@ -80,6 +81,14 @@ class Sample(ABC):
|
||||
def get_array(self, key: str) -> Optional[np.ndarray]:
|
||||
pass
|
||||
|
||||
@abstractmethod
def put_sparse(self, key: str, value: coo_matrix) -> None:
    """Store the sparse matrix `value` under `key`.

    Abstract hook: this base declaration does nothing itself and returns
    None; concrete subclasses supply the actual storage.
    """
|
||||
|
||||
@abstractmethod
def get_sparse(self, key: str) -> Optional[coo_matrix]:
    """Return the sparse matrix stored under `key`.

    Abstract hook: this base declaration does nothing itself and returns
    None; concrete subclasses supply the actual lookup. The return type
    allows None as a result.
    """
|
||||
|
||||
def get_set(self, key: str) -> Set:
|
||||
v = self.get_vector(key)
|
||||
if v:
|
||||
@@ -118,6 +127,10 @@ class Sample(ABC):
|
||||
assert isinstance(value, np.ndarray)
|
||||
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
|
||||
|
||||
def _assert_is_sparse(self, value: Any) -> None:
|
||||
assert isinstance(value, coo_matrix)
|
||||
self._assert_supported(value.data)
|
||||
|
||||
|
||||
class MemorySample(Sample):
|
||||
"""Dictionary-like class that stores training data in-memory."""
|
||||
@@ -197,6 +210,17 @@ class MemorySample(Sample):
|
||||
def get_array(self, key: str) -> Optional[np.ndarray]:
|
||||
return cast(Optional[np.ndarray], self._get(key))
|
||||
|
||||
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
    """Validate `value` and store it under `key` via `_put`.

    A `value` of None is ignored: nothing is validated and nothing is stored.
    """
    if value is not None:
        self._assert_is_sparse(value)
        self._put(key, value)
|
||||
|
||||
@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
    """Return whatever `_get` yields for `key`, typed as an optional COO matrix."""
    stored = self._get(key)
    # cast() only informs the type checker; the object is returned unchanged.
    return cast(Optional[coo_matrix], stored)
|
||||
|
||||
|
||||
class Hdf5Sample(Sample):
|
||||
"""
|
||||
@@ -351,6 +375,26 @@ class Hdf5Sample(Sample):
|
||||
return None
|
||||
return self.file[key][:]
|
||||
|
||||
@overrides
def put_sparse(self, key: str, value: coo_matrix) -> None:
    """Store the COO matrix `value` as several arrays derived from `key`.

    The matrix is written through `put_array` under the keys
    `{key}_row`, `{key}_col`, `{key}_data` and `{key}_shape`. A `value` of
    None is ignored and nothing is written.

    The shape is stored explicitly because the index arrays alone only
    reveal the largest used row/column; without it, trailing all-zero
    rows or columns would be lost when the matrix is read back.
    """
    if value is None:
        return
    self._assert_is_sparse(value)
    self.put_array(f"{key}_row", value.row)
    self.put_array(f"{key}_col", value.col)
    self.put_array(f"{key}_data", value.data)
    self.put_array(f"{key}_shape", np.array(value.shape, dtype=np.int64))

@overrides
def get_sparse(self, key: str) -> Optional[coo_matrix]:
    """Rebuild the COO matrix previously written by `put_sparse` for `key`.

    Returns None when no `{key}_row` array is present. When a
    `{key}_shape` array exists, the matrix is rebuilt with exactly that
    shape; otherwise the dimensions are inferred from the index arrays,
    which keeps data written without a shape entry readable.
    """
    row = self.get_array(f"{key}_row")
    if row is None:
        return None
    col = self.get_array(f"{key}_col")
    data = self.get_array(f"{key}_data")
    assert col is not None
    assert data is not None
    dims = self.get_array(f"{key}_shape")
    if dims is None:
        # No stored shape (written by the earlier format): let scipy infer
        # the dimensions from the largest row/column index.
        return coo_matrix((data, (row, col)))
    return coo_matrix(
        (data, (row, col)),
        shape=(int(dims[0]), int(dims[1])),
    )
|
||||
|
||||
|
||||
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
|
||||
veclist = deepcopy(veclist)
|
||||
|
||||
Reference in New Issue
Block a user