@@ -168,20 +168,29 @@ def get_chunk(self, i_start=None, i_stop=None):
168
168
169
169
# this fit the neo API >= 0.10 (with streams concept)
170
170
class AnalogSignalFromNeoRawIOSource (BaseAnalogSignalSource ):
171
- def __init__ (self , neorawio , stream_index ):
171
+ def __init__ (self , neorawio , channel_indexes = None , stream_index = None ):
172
172
173
173
BaseAnalogSignalSource .__init__ (self )
174
174
self .with_scatter = False
175
175
176
- self .neorawio = neorawio
177
- self .stream_index = stream_index
178
-
179
-
180
- self .stream_id = self .neorawio .header ['signal_streams' ][stream_index ]['id' ]
176
+ self .neorawio = neorawio
177
+
178
+ if stream_index is not None :
179
+ self .stream_index = stream_index
180
+ elif self .neorawio .signal_streams_count () == 1 :
181
+ self .stream_index = 0
182
+ else :
183
+ raise ValueError (f'Because the Neo RawIO source contains multiple signal streams ({ self .neorawio .signal_streams_count ()} ), stream_index must be provided' )
184
+
185
+ if channel_indexes is None :
186
+ channel_indexes = slice (None )
187
+ self .channel_indexes = channel_indexes
188
+
189
+ self .stream_id = self .neorawio .header ['signal_streams' ][self .stream_index ]['id' ]
181
190
signal_channels = self .neorawio .header ['signal_channels' ]
182
191
mask = signal_channels ['stream_id' ] == self .stream_id
183
- self .channels = signal_channels [mask ]
184
-
192
+ self .channels = signal_channels [mask ][ self . channel_indexes ]
193
+
185
194
self .sample_rate = self .neorawio .get_signal_sampling_rate (stream_index = self .stream_index )
186
195
187
196
#TODO: something for multi segment
@@ -222,8 +231,8 @@ def get_shape(self):
222
231
223
232
def get_chunk (self , i_start = None , i_stop = None ):
224
233
sigs = self .neorawio .get_analogsignal_chunk (block_index = self .block_index , seg_index = self .seg_index ,
225
- i_start = i_start , i_stop = i_stop , stream_index = self .stream_index ,
226
- channel_indexes = None )
234
+ i_start = i_start , i_stop = i_stop , stream_index = self .stream_index ,
235
+ channel_indexes = self . channel_indexes )
227
236
return sigs
228
237
229
238
@@ -387,7 +396,7 @@ def get_sources_from_neo_rawio(neorawio):
387
396
388
397
sources = {'signal' :[], 'epoch' :[], 'spike' :[]}
389
398
390
-
399
+
391
400
# handle of neo version
392
401
# this will be simplified in a while
393
402
if hasattr (neorawio , 'get_group_signal_channel_indexes' ):
@@ -396,23 +405,23 @@ def get_sources_from_neo_rawio(neorawio):
396
405
channel_indexes_list = neorawio .get_group_signal_channel_indexes ()
397
406
for channel_indexes in channel_indexes_list :
398
407
#one soure by channel group
399
- sources ['signal' ].append (AnalogSignalFromNeoRawIOSource_until_v9 (neorawio , channel_indexes ))
408
+ sources ['signal' ].append (AnalogSignalFromNeoRawIOSource_until_v9 (neorawio , channel_indexes = channel_indexes ))
400
409
elif hasattr (neorawio , 'get_group_channel_indexes' ):
401
410
# Neo < 0.9.0
402
411
if neorawio .signal_channels_count () > 0 :
403
412
channel_indexes_list = neorawio .get_group_signal_channel_indexes ()
404
413
for channel_indexes in channel_indexes_list :
405
414
#one soure by channel group
406
- sources ['signal' ].append (AnalogSignalFromNeoRawIOSource_until_v9 (neorawio , channel_indexes ))
415
+ sources ['signal' ].append (AnalogSignalFromNeoRawIOSource_until_v9 (neorawio , channel_indexes = channel_indexes ))
407
416
elif hasattr (neorawio , 'signal_streams_count' ):
408
- # Neo >= 0.10.0 (not release yet in march 2021)
417
+ # Neo >= 0.10.0
409
418
num_streams = neorawio .signal_streams_count ()
410
419
for stream_index in range (num_streams ):
411
420
#one soure by stream
412
- sources ['signal' ].append (AnalogSignalFromNeoRawIOSource (neorawio , stream_index ))
421
+ sources ['signal' ].append (AnalogSignalFromNeoRawIOSource (neorawio , stream_index = stream_index ))
422
+
413
423
414
424
415
-
416
425
if hasattr (neorawio , 'unit_channels_count' ):
417
426
# Neo < 0.10
418
427
if neorawio .unit_channels_count ()> 0 :
@@ -421,7 +430,7 @@ def get_sources_from_neo_rawio(neorawio):
421
430
# neo >= 0.10
422
431
if neorawio .spike_channels_count ()> 0 :
423
432
sources ['spike' ].append (SpikeFromNeoRawIOSource (neorawio , None ))
424
-
433
+
425
434
426
435
if neorawio .event_channels_count ()> 0 :
427
436
sources ['epoch' ].append (EpochFromNeoRawIOSource (neorawio , None ))
0 commit comments