Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
rlizzo committed Dec 5, 2019
1 parent a13eb46 commit 9587e93
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions src/hangar/dataloaders/grouper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np

from ..arrayset import ArraysetDataReader
from ..records.hashmachine import array_hash_digest

from collections import defaultdict
import hashlib
from typing import Sequence, Union, Iterable, NamedTuple
import struct
from typing import Sequence, Union, Iterable, NamedTuple, Tuple


# -------------------------- typehints ---------------------------------------
Expand All @@ -21,21 +20,14 @@
# ------------------------------------------------------------------------------


def _calculate_hash_digest(data: np.ndarray) -> str:
hasher = hashlib.blake2b(data, digest_size=20)
hasher.update(struct.pack(f'<{len(data.shape)}QB', *data.shape, data.dtype.num))
digest = hasher.hexdigest()
return digest


class FakeNumpyKeyDict(object):
def __init__(self, group_spec_samples, group_spec_value, group_digest_spec):
self._group_spec_samples = group_spec_samples
self._group_spec_value = group_spec_value
self._group_digest_spec = group_digest_spec

def __getitem__(self, key: np.ndarray) -> ArraysetSampleNames:
digest = _calculate_hash_digest(key)
digest = array_hash_digest(key)
spec = self._group_digest_spec[digest]
samples = self._group_spec_samples[spec]
return samples
Expand All @@ -53,7 +45,7 @@ def __len__(self) -> int:
return len(self._group_digest_spec)

def __contains__(self, key: np.ndarray) -> bool:
digest = _calculate_hash_digest(key)
digest = array_hash_digest(key)
res = True if digest in self._group_digest_spec else False
return res

Expand All @@ -69,7 +61,7 @@ def values(self) -> Iterable[ArraysetSampleNames]:
for spec in self._group_digest_spec.values():
yield self._group_spec_samples[spec]

def items(self) -> Iterable[ArraysetSampleNames]:
def items(self) -> Iterable[Tuple[np.ndarray, ArraysetSampleNames]]:
for spec in self._group_digest_spec.values():
yield (self._group_spec_value[spec], self._group_spec_samples[spec])

Expand All @@ -81,11 +73,10 @@ def __repr__(self):
def _repr_pretty_(self, p, cycle):
res = f'Mapping: Group Data Value -> Sample Name \n'
for k, v in self.items():
res += f'\n {k} :: {v}'
res += f'\n {k} :: {v} \n'
p.text(res)



# ---------------------------- MAIN METHOD ------------------------------------


Expand All @@ -112,7 +103,7 @@ def _setup(self):
for spec, names in self._group_spec_samples.items():
data = self.__arrayset._fs[spec.backend].read_data(spec)
self._group_spec_value[spec] = data
digest = _calculate_hash_digest(data)
digest = array_hash_digest(data)
self._group_digest_spec[digest] = spec

@property
Expand Down

0 comments on commit 9587e93

Please sign in to comment.