Skip to content

Commit 6314f2e

Browse files
committed
Automatically use first Neo signal stream when only one exists
1 parent b4ef24e commit 6314f2e

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

ephyviewer/datasource/neosource.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -168,19 +168,25 @@ def get_chunk(self, i_start=None, i_stop=None):
168168

169169
# this fit the neo API >= 0.10 (with streams concept)
170170
class AnalogSignalFromNeoRawIOSource(BaseAnalogSignalSource):
171-
def __init__(self, neorawio, stream_index, channel_indexes=None):
171+
def __init__(self, neorawio, channel_indexes=None, stream_index=None):
172172

173173
BaseAnalogSignalSource.__init__(self)
174174
self.with_scatter = False
175175

176-
self.neorawio =neorawio
177-
self.stream_index = stream_index
176+
self.neorawio = neorawio
177+
178+
if stream_index is not None:
179+
self.stream_index = stream_index
180+
elif self.neorawio.signal_streams_count() == 1:
181+
self.stream_index = 0
182+
else:
183+
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
184+
178185
if channel_indexes is None:
179186
channel_indexes = slice(None)
180187
self.channel_indexes = channel_indexes
181-
182-
183-
self.stream_id = self.neorawio.header['signal_streams'][stream_index]['id']
188+
189+
self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
184190
signal_channels = self.neorawio.header['signal_channels']
185191
mask = signal_channels['stream_id'] == self.stream_id
186192
self.channels = signal_channels[mask][self.channel_indexes]
@@ -225,7 +231,7 @@ def get_shape(self):
225231

226232
def get_chunk(self, i_start=None, i_stop=None):
227233
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
228-
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
234+
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
229235
channel_indexes=self.channel_indexes)
230236
return sigs
231237

@@ -390,7 +396,7 @@ def get_sources_from_neo_rawio(neorawio):
390396

391397
sources = {'signal':[], 'epoch':[], 'spike':[]}
392398

393-
399+
394400
# handle of neo version
395401
# this will be simplified in a while
396402
if hasattr(neorawio, 'get_group_signal_channel_indexes'):
@@ -399,23 +405,23 @@ def get_sources_from_neo_rawio(neorawio):
399405
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
400406
for channel_indexes in channel_indexes_list:
401407
#one soure by channel group
402-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes))
408+
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
403409
elif hasattr(neorawio, 'get_group_channel_indexes'):
404410
# Neo < 0.9.0
405411
if neorawio.signal_channels_count() > 0:
406412
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
407413
for channel_indexes in channel_indexes_list:
408414
#one soure by channel group
409-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes))
415+
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
410416
elif hasattr(neorawio, 'signal_streams_count'):
411-
# Neo >= 0.10.0 (not release yet in march 2021)
417+
# Neo >= 0.10.0
412418
num_streams = neorawio.signal_streams_count()
413419
for stream_index in range(num_streams):
414420
#one soure by stream
415-
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index))
421+
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))
422+
416423

417424

418-
419425
if hasattr(neorawio, 'unit_channels_count'):
420426
# Neo < 0.10
421427
if neorawio.unit_channels_count()>0:
@@ -424,7 +430,7 @@ def get_sources_from_neo_rawio(neorawio):
424430
# neo >= 0.10
425431
if neorawio.spike_channels_count()>0:
426432
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))
427-
433+
428434

429435
if neorawio.event_channels_count()>0:
430436
sources['epoch'].append(EpochFromNeoRawIOSource(neorawio, None))

0 commit comments

Comments
 (0)