Skip to content

Commit

Permalink
Merge pull request #1539 from zm711/slice-none
Browse files Browse the repository at this point in the history
Add test that all rawios accept `slice(None)`
  • Loading branch information
zm711 authored Aug 29, 2024
2 parents 3b5488c + e2c1cf8 commit f0f4cc6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
13 changes: 11 additions & 2 deletions neo/rawio/medrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,18 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
self.sess.set_channel_active(self._stream_info[stream_index]["raw_chans"])
num_channels = len(self._stream_info[stream_index]["raw_chans"])
self.sess.set_reference_channel(self._stream_info[stream_index]["raw_chans"][0])

# in the case we have a slice or we give an ArrayLike we need to iterate through the channels
# in order to activate them.
else:
if any(channel_indexes < 0):
raise IndexError(f"Can not index negative channels: {channel_indexes}")
if isinstance(channel_indexes, slice):
start = channel_indexes.start or 0
stop = channel_indexes.stop or len(self._stream_info[stream_index]["raw_chans"])
step = channel_indexes.step or 1
channel_indexes = [ch for ch in range(start, stop, step)]
else:
if any(channel_indexes < 0):
raise IndexError(f"Can not index negative channels: {channel_indexes}")
# Set all channels to be inactive, then selectively set some of them to be active
self.sess.set_channel_inactive("all")
for i, channel_idx in enumerate(channel_indexes):
Expand Down
25 changes: 24 additions & 1 deletion neo/test/rawiotest/rawio_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def read_analogsignals(reader):
channel_names = signal_channels["name"][mask]
channel_ids = signal_channels["id"][mask]

# acces by channel inde/ids/names should give the same chunk
# acces by channel index/ids/names should give the same chunk
channel_indexes2 = channel_indexes[::2]
channel_names2 = channel_names[::2]
channel_ids2 = channel_ids[::2]
Expand Down Expand Up @@ -214,6 +214,29 @@ def read_analogsignals(reader):
)
np.testing.assert_array_equal(raw_chunk0, raw_chunk1)

# test slice(None). This should return the same array as giving
# all channel indexes or using None as an argument in `get_analogsignal_chunk`
# see https://github.com/NeuralEnsemble/python-neo/issues/1533

raw_chunk_slice_none = reader.get_analogsignal_chunk(
block_index=block_index,
seg_index=seg_index,
i_start=i_start,
i_stop=i_stop,
stream_index=stream_index,
channel_indexes=slice(None)
)
raw_chunk_channel_indexes = reader.get_analogsignal_chunk(
block_index=block_index,
seg_index=seg_index,
i_start=i_start,
i_stop=i_stop,
stream_index=stream_index,
channel_indexes=channel_indexes
)

np.testing.assert_array_equal(raw_chunk_slice_none, raw_chunk_channel_indexes)

# test prefer_slice=True/False
if nb_chan >= 3:
for prefer_slice in (True, False):
Expand Down

0 comments on commit f0f4cc6

Please sign in to comment.