diff --git a/extra_data/validation.py b/extra_data/validation.py index d40bc22d..63557ed1 100644 --- a/extra_data/validation.py +++ b/extra_data/validation.py @@ -72,13 +72,13 @@ def check_trainids(self): dataset=ds_path, ) - def _get_index(self, source, group): + def _get_index(self, path): """returns first and count dataset for specified source. This is slightly different to the same method in FileAccess as it does cut the dataset up to the trainId's dataset length. """ - ix_group = self.file.file['/INDEX/{}/{}'.format(source, group)] + ix_group = self.file.file[path] firsts = ix_group['first'][:] if 'count' in ix_group: counts = ix_group['count'][:] @@ -88,6 +88,23 @@ def _get_index(self, source, group): return firsts, counts def check_indices(self): + for src in self.file.control_sources: + first, count = self.file.get_index(src, '') + for key in self.file.get_keys(src): + ds_path = f"CONTROL/{src}/{key.replace('.', '/')}" + data_dim0 = self.file.file[ds_path].shape[0] + if np.any((first + count) > data_dim0): + max_end = (first + count).max() + self.record( + 'Index referring to data ({}) outside dataset ({})'.format( + max_end, data_dim0 + ), + dataset=ds_path, + ) + break # Recording every key separately can make a *lot* of errors + + self._check_index(f'INDEX/{src}') + for src in self.file.instrument_sources: src_groups = set() for key in self.file.get_keys(src): @@ -106,33 +123,36 @@ def check_indices(self): ) for src, group in src_groups: - record = partial(self.record, dataset='INDEX/{}/{}'.format(src, group)) - first, count = self._get_index(src, group) - - if (first.ndim != 1) or (count.ndim != 1): - record( - "Index first / count are not 1D", - first_shape=first.shape, - count_shape=count.shape, - ) - continue + self._check_index(f'INDEX/{src}/{group}') - if first.shape != count.shape: - record( - "Index first & count have different number of entries", - first_shape=first.shape, - count_shape=count.shape, - ) - continue + def _check_index(self, path): + record = partial(self.record, dataset=path) + first, count = self._get_index(path) - if first.shape != self.file.train_ids.shape: - record( - "Index has wrong number of entries", - index_shape=first.shape, - trainids_shape=self.file.train_ids.shape, - ) + if (first.ndim != 1) or (count.ndim != 1): + record( + "Index first / count are not 1D", + first_shape=first.shape, + count_shape=count.shape, + ) + return + + if first.shape != count.shape: + record( + "Index first & count have different number of entries", + first_shape=first.shape, + count_shape=count.shape, + ) + return + + if first.shape != self.file.train_ids.shape: + record( + "Index has wrong number of entries", + index_shape=first.shape, + trainids_shape=self.file.train_ids.shape, + ) - check_index_contiguous(first, count, record) + check_index_contiguous(first, count, record) def check_index_contiguous(firsts, counts, record):