Validate indexes for CONTROL data #188

Merged: 3 commits, May 28, 2021

Changes from all commits
extra_data/validation.py: 72 changes (46 additions, 26 deletions)
@@ -72,13 +72,13 @@ def check_trainids(self):
dataset=ds_path,
)

def _get_index(self, source, group):
def _get_index(self, path):
"""returns first and count dataset for specified source.

This is slightly different to the same method in FileAccess as it does
cut the dataset up to the trainId's dataset length.
"""
ix_group = self.file.file['/INDEX/{}/{}'.format(source, group)]
ix_group = self.file.file[path]
firsts = ix_group['first'][:]
if 'count' in ix_group:
counts = ix_group['count'][:]
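
For context, the index groups read by `_get_index` are plain HDF5 groups containing a `first` dataset and, in newer files, a `count` dataset. Below is a minimal standalone sketch of the same read using h5py; the helper name and file name are illustrative, and the collapsed branch for files without `count` is not reproduced:

```python
import h5py

def read_index_group(filename, path):
    """Sketch of what _get_index reads, outside the validator class.

    `path` is e.g. 'INDEX/<source>' or 'INDEX/<source>/<group>'.
    """
    with h5py.File(filename, 'r') as f:
        ix_group = f[path]
        firsts = ix_group['first'][:]   # read into memory as a NumPy array
        if 'count' in ix_group:
            counts = ix_group['count'][:]
        else:
            # Older files store the index differently; the real method's
            # fallback is collapsed in the diff above, so it is omitted here.
            raise NotImplementedError("old-format index group")
    return firsts, counts
```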
@@ -88,6 +88,23 @@ def _get_index(self, source, group):
return firsts, counts

def check_indices(self):
for src in self.file.control_sources:
first, count = self.file.get_index(src, '')
for key in self.file.get_keys(src):
ds_path = f"CONTROL/{src}/{key.replace('.', '/')}"
data_dim0 = self.file.file[ds_path].shape[0]
if np.any((first + count) > data_dim0):
max_end = (first + count).max()
self.record(
'Index referring to data ({}) outside dataset ({})'.format(
max_end, data_dim0
),
dataset=ds_path,
)
break # Recording every key separately can make a *lot* of errors

self._check_index(f'INDEX/{src}')

for src in self.file.instrument_sources:
src_groups = set()
for key in self.file.get_keys(src):
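
The CONTROL check added above reduces to a bounds test: for every train, `first + count` must not point past the end of the corresponding CONTROL dataset. A minimal sketch of that test with made-up values:

```python
import numpy as np

# Per-train index entries: train i's data starts at row first[i] of the
# CONTROL dataset and spans count[i] rows.
first = np.array([0, 2, 4, 6], dtype=np.uint64)
count = np.array([2, 2, 2, 3], dtype=np.uint64)

data_dim0 = 8  # length of the CONTROL dataset along its first axis

ends = first + count  # one past the last row each train claims
if np.any(ends > data_dim0):
    # Here the last entry claims rows 6, 7 and 8, but only rows 0..7 exist.
    print(f'Index referring to data ({ends.max()}) outside dataset ({data_dim0})')
```

Breaking out of the key loop after the first hit, as the diff does, keeps one bad source from producing an error per key.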
@@ -106,33 +123,36 @@ def check_indices(self):
)

for src, group in src_groups:
record = partial(self.record, dataset='INDEX/{}/{}'.format(src, group))
first, count = self._get_index(src, group)

if (first.ndim != 1) or (count.ndim != 1):
record(
"Index first / count are not 1D",
first_shape=first.shape,
count_shape=count.shape,
)
continue
self._check_index(f'INDEX/{src}/{group}')

if first.shape != count.shape:
record(
"Index first & count have different number of entries",
first_shape=first.shape,
count_shape=count.shape,
)
continue
def _check_index(self, path):
record = partial(self.record, dataset=path)
first, count = self._get_index(path)

if first.shape != self.file.train_ids.shape:
record(
"Index has wrong number of entries",
index_shape=first.shape,
trainids_shape=self.file.train_ids.shape,
)
if (first.ndim != 1) or (count.ndim != 1):
record(
"Index first / count are not 1D",
first_shape=first.shape,
count_shape=count.shape,
)
return

if first.shape != count.shape:
record(
"Index first & count have different number of entries",
first_shape=first.shape,
count_shape=count.shape,
)
return

if first.shape != self.file.train_ids.shape:
record(
"Index has wrong number of entries",
index_shape=first.shape,
trainids_shape=self.file.train_ids.shape,
)

check_index_contiguous(first, count, record)
check_index_contiguous(first, count, record)


def check_index_contiguous(firsts, counts, record):
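
`check_index_contiguous`, reused at the end of `_check_index`, is collapsed in this diff. As an illustration only (not the function's actual body), the contiguity condition can be sketched as: the index should start at zero and each train's block should begin exactly where the previous one ended:

```python
import numpy as np

def contiguity_problems(firsts, counts):
    """Sketch of the contiguity condition; not the real check_index_contiguous."""
    problems = []
    if firsts.size == 0:
        return problems
    if firsts[0] != 0:
        problems.append("Index doesn't start at 0")
    # Each block should start where the previous block ended.
    if np.any(firsts[1:] != firsts[:-1] + counts[:-1]):
        problems.append("Gaps or overlaps between trains in the index")
    return problems
```

With the example arrays from the earlier sketch (`first = [0, 2, 4, 6]`, `count = [2, 2, 2, 3]`) this reports nothing; changing `first` to `[0, 2, 5, 6]` would flag a gap or overlap.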