Skip to content

Commit 5b2b051

Browse files
authored
Merge pull request #157 from samuelgarcia/neo_rawio_channel_index
Add back channel_indexes in AnalogSignalFromNeoRawIOSource
2 parents 094c6b8 + 6314f2e commit 5b2b051

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

ephyviewer/datasource/neosource.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,29 @@ def get_chunk(self, i_start=None, i_stop=None):
168168

169169
# this fit the neo API >= 0.10 (with streams concept)
170170
class AnalogSignalFromNeoRawIOSource(BaseAnalogSignalSource):
171-
def __init__(self, neorawio, stream_index):
171+
def __init__(self, neorawio, channel_indexes=None, stream_index=None):
172172

173173
BaseAnalogSignalSource.__init__(self)
174174
self.with_scatter = False
175175

176-
self.neorawio =neorawio
177-
self.stream_index = stream_index
178-
179-
180-
self.stream_id = self.neorawio.header['signal_streams'][stream_index]['id']
176+
self.neorawio = neorawio
177+
178+
if stream_index is not None:
179+
self.stream_index = stream_index
180+
elif self.neorawio.signal_streams_count() == 1:
181+
self.stream_index = 0
182+
else:
183+
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
184+
185+
if channel_indexes is None:
186+
channel_indexes = slice(None)
187+
self.channel_indexes = channel_indexes
188+
189+
self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
181190
signal_channels = self.neorawio.header['signal_channels']
182191
mask = signal_channels['stream_id'] == self.stream_id
183-
self.channels = signal_channels[mask]
184-
192+
self.channels = signal_channels[mask][self.channel_indexes]
193+
185194
self.sample_rate = self.neorawio.get_signal_sampling_rate(stream_index=self.stream_index)
186195

187196
#TODO: something for multi segment
@@ -222,8 +231,8 @@ def get_shape(self):
222231

223232
def get_chunk(self, i_start=None, i_stop=None):
224233
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
225-
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
226-
channel_indexes=None)
234+
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
235+
channel_indexes=self.channel_indexes)
227236
return sigs
228237

229238

@@ -387,7 +396,7 @@ def get_sources_from_neo_rawio(neorawio):
387396

388397
sources = {'signal':[], 'epoch':[], 'spike':[]}
389398

390-
399+
391400
# handle of neo version
392401
# this will be simplified in a while
393402
if hasattr(neorawio, 'get_group_signal_channel_indexes'):
@@ -396,23 +405,23 @@ def get_sources_from_neo_rawio(neorawio):
396405
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
397406
for channel_indexes in channel_indexes_list:
398407
#one soure by channel group
399-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes))
408+
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
400409
elif hasattr(neorawio, 'get_group_channel_indexes'):
401410
# Neo < 0.9.0
402411
if neorawio.signal_channels_count() > 0:
403412
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
404413
for channel_indexes in channel_indexes_list:
405414
#one soure by channel group
406-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes))
415+
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
407416
elif hasattr(neorawio, 'signal_streams_count'):
408-
# Neo >= 0.10.0 (not release yet in march 2021)
417+
# Neo >= 0.10.0
409418
num_streams = neorawio.signal_streams_count()
410419
for stream_index in range(num_streams):
411420
#one soure by stream
412-
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index))
421+
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))
422+
413423

414424

415-
416425
if hasattr(neorawio, 'unit_channels_count'):
417426
# Neo < 0.10
418427
if neorawio.unit_channels_count()>0:
@@ -421,7 +430,7 @@ def get_sources_from_neo_rawio(neorawio):
421430
# neo >= 0.10
422431
if neorawio.spike_channels_count()>0:
423432
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))
424-
433+
425434

426435
if neorawio.event_channels_count()>0:
427436
sources['epoch'].append(EpochFromNeoRawIOSource(neorawio, None))

ephyviewer/tests/test_datasource.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,6 @@ def test_spikeinterface_sources():
179179
#~ test_InMemoryEventSource()
180180
#~ test_InMemoryEpochSource()
181181
#~ test_spikesource()
182-
#~ test_neo_rawio_sources()
182+
test_neo_rawio_sources()
183183
#~ test_neo_object_sources()
184-
test_spikeinterface_sources()
184+
#~ test_spikeinterface_sources()

0 commit comments

Comments
 (0)