diff --git a/src/hangar/dataloaders/grouper.py b/src/hangar/dataloaders/grouper.py index 0ba618ac..e6c16241 100644 --- a/src/hangar/dataloaders/grouper.py +++ b/src/hangar/dataloaders/grouper.py @@ -1,11 +1,10 @@ import numpy as np from ..arrayset import ArraysetDataReader +from ..records.hashmachine import array_hash_digest from collections import defaultdict -import hashlib -from typing import Sequence, Union, Iterable, NamedTuple -import struct +from typing import Sequence, Union, Iterable, NamedTuple, Tuple # -------------------------- typehints --------------------------------------- @@ -21,13 +20,6 @@ # ------------------------------------------------------------------------------ -def _calculate_hash_digest(data: np.ndarray) -> str: - hasher = hashlib.blake2b(data, digest_size=20) - hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num)) - digest = hasher.hexdigest() - return digest - - class FakeNumpyKeyDict(object): def __init__(self, group_spec_samples, group_spec_value, group_digest_spec): self._group_spec_samples = group_spec_samples @@ -35,7 +27,7 @@ def __init__(self, group_spec_samples, group_spec_value, group_digest_spec): self._group_digest_spec = group_digest_spec def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames: - digest = _calculate_hash_digest(key) + digest = array_hash_digest(key) spec = self._group_digest_spec[digest] samples = self._group_spec_samples[spec] return samples @@ -53,7 +45,7 @@ def __len__(self) -> int: return len(self._group_digest_spec) def __contains__(self, key: np.ndarray) -> bool: - digest = _calculate_hash_digest(key) + digest = array_hash_digest(key) res = True if digest in self._group_digest_spec else False return res @@ -69,7 +61,7 @@ def values(self) -> Iterable[ArraysetSampleNames]: for spec in self._group_digest_spec.values(): yield self._group_spec_samples[spec] - def items(self) -> Iterable[ArraysetSampleNames]: + def items(self) -> Iterable[Tuple[np.ndarray, ArraysetSampleNames]]: for spec in self._group_digest_spec.values(): yield (self._group_spec_value[spec], self._group_spec_samples[spec]) @@ -81,11 +73,10 @@ def __repr__(self): def _repr_pretty_(self, p, cycle): res = f'Mapping: Group Data Value -> Sample Name \n' for k, v in self.items(): - res += f'\n {k} :: {v}' + res += f'\n {k} :: {v} \n' p.text(res) - # ---------------------------- MAIN METHOD ------------------------------------ @@ -112,7 +103,7 @@ def _setup(self): for spec, names in self._group_spec_samples.items(): data = self.__arrayset._fs[spec.backend].read_data(spec) self._group_spec_value[spec] = data - digest = _calculate_hash_digest(data) + digest = array_hash_digest(data) self._group_digest_spec[digest] = spec @property