Skip to content

Commit 7c82f8a

Browse files
authored
Merge pull request #159 from jpgill86/neo-backwards-compatibility
Make Neo RawIO sources compatible with Neo 0.6-0.10
2 parents 5b2b051 + 283c8ec commit 7c82f8a

File tree

2 files changed

+97
-147
lines changed

2 files changed

+97
-147
lines changed

ephyviewer/datasource/neosource.py

+90-140
Original file line numberDiff line numberDiff line change
@@ -110,88 +110,70 @@ def get_sources_from_neo_segment(neo_seg):
110110

111111
## neo.rawio stuff
112112

113-
# this can be remove when neo version 0.10 will be out
114-
class AnalogSignalFromNeoRawIOSource_until_v9(BaseAnalogSignalSource):
115-
def __init__(self, neorawio, channel_indexes=None):
116-
117-
BaseAnalogSignalSource.__init__(self)
118-
self.with_scatter = False
119-
120-
self.neorawio =neorawio
121-
if channel_indexes is None:
122-
channel_indexes = slice(None)
123-
self.channel_indexes = channel_indexes
124-
self.channels = self.neorawio.header['signal_channels'][channel_indexes]
125-
self.sample_rate = self.neorawio.get_signal_sampling_rate(channel_indexes=self.channel_indexes)
126-
127-
#TODO: something for multi segment
128-
self.block_index = 0
129-
self.seg_index = 0
130-
131-
@property
132-
def nb_channel(self):
133-
return len(self.channels)
134-
135-
def get_channel_name(self, chan=0):
136-
return self.channels[chan]['name']
137-
138-
@property
139-
def t_start(self):
140-
t_start = self.neorawio.get_signal_t_start(self.block_index, self.seg_index,
141-
channel_indexes=self.channel_indexes)
142-
return t_start
143-
144-
@property
145-
def t_stop(self):
146-
t_stop = self.t_start + self.get_length()/self.sample_rate
147-
return t_stop
148-
149-
def get_length(self):
150-
length = self.neorawio.get_signal_size(self.block_index, self.seg_index,
151-
channel_indexes=self.channel_indexes)
152-
return length
153-
154-
def get_gains(self):
155-
return self.neorawio.header['signal_channels']['gain'][self.channel_indexes]
156-
157-
def get_offsets(self):
158-
return self.neorawio.header['signal_channels']['offset'][self.channel_indexes]
159-
160-
def get_shape(self):
161-
return (self.get_length(), self.nb_channel)
162-
163-
def get_chunk(self, i_start=None, i_stop=None):
164-
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
165-
i_start=i_start, i_stop=i_stop, channel_indexes=self.channel_indexes)
166-
return sigs
167-
168-
169-
# this fit the neo API >= 0.10 (with streams concept)
170113
class AnalogSignalFromNeoRawIOSource(BaseAnalogSignalSource):
171114
def __init__(self, neorawio, channel_indexes=None, stream_index=None):
115+
"""
116+
Create an analog signal source from a Neo RawIO.
117+
118+
Parameters
119+
----------
120+
neorawio : subclass of neo.rawio.BaseRawIO
121+
Neo RawIO reader from which to load signals.
122+
channel_indexes : list of ints
123+
Indexes of signals to use. Note that for Neo>=0.10, channels within
124+
a signal stream are indexed independently of channels in other
125+
streams; for Neo<0.10, channels are indexed globally, regardless of
126+
signal group membership. If None is passed, uses all channels within
127+
a stream (or all channels globally for Neo<0.10).
128+
stream_index : int
129+
Index of signal stream to use. If only one signal stream exists,
130+
this parameter is not required. For Neo<0.10, signal streams do not
131+
exist and this parameter must be None.
132+
"""
172133

173134
BaseAnalogSignalSource.__init__(self)
174135
self.with_scatter = False
175136

176137
self.neorawio = neorawio
177-
178-
if stream_index is not None:
179-
self.stream_index = stream_index
180-
elif self.neorawio.signal_streams_count() == 1:
181-
self.stream_index = 0
182-
else:
183-
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
138+
if self.neorawio.header is None:
139+
self.neorawio.parse_header()
184140

185141
if channel_indexes is None:
186142
channel_indexes = slice(None)
187143
self.channel_indexes = channel_indexes
188144

189-
self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
190-
signal_channels = self.neorawio.header['signal_channels']
191-
mask = signal_channels['stream_id'] == self.stream_id
192-
self.channels = signal_channels[mask][self.channel_indexes]
145+
if V(neo.__version__)>='0.10.0':
146+
# Neo >= 0.10
147+
# - versions 0.10+ index channels within a stream
148+
if stream_index is not None:
149+
self.stream_index = stream_index
150+
elif self.neorawio.signal_streams_count() == 1:
151+
self.stream_index = 0
152+
else:
153+
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
154+
self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
155+
signal_channels = self.neorawio.header['signal_channels']
156+
mask = signal_channels['stream_id'] == self.stream_id
157+
self.channels = signal_channels[mask][self.channel_indexes]
158+
else:
159+
# Neo < 0.10
160+
# - versions 0.6-0.9 index channels globally (ignoring signal group)
161+
assert stream_index is None, f'Neo version {neo.__version__} is installed, but only Neo>=0.10 uses stream_index'
162+
self.channels = self.neorawio.header['signal_channels'][self.channel_indexes]
163+
164+
if V(neo.__version__)>='0.10.0':
165+
# Neo >= 0.10
166+
# - versions 0.10+ use stream_index as an argument often,
167+
# but also require channel_indexes for get_chunk
168+
self.signal_indexing_kwarg = {'stream_index': self.stream_index}
169+
self.get_chunk_kwargs = {'stream_index': self.stream_index, 'channel_indexes': self.channel_indexes}
170+
else:
171+
# Neo < 0.10
172+
# - versions 0.6-0.9 use channel_indexes as an argument often
173+
self.signal_indexing_kwarg = {'channel_indexes': self.channel_indexes}
174+
self.get_chunk_kwargs = {'channel_indexes': self.channel_indexes}
193175

194-
self.sample_rate = self.neorawio.get_signal_sampling_rate(stream_index=self.stream_index)
176+
self.sample_rate = self.neorawio.get_signal_sampling_rate(**self.signal_indexing_kwarg)
195177

196178
#TODO: something for multi segment
197179
self.block_index = 0
@@ -207,7 +189,7 @@ def get_channel_name(self, chan=0):
207189
@property
208190
def t_start(self):
209191
t_start = self.neorawio.get_signal_t_start(self.block_index, self.seg_index,
210-
stream_index=self.stream_index)
192+
**self.signal_indexing_kwarg)
211193
return t_start
212194

213195
@property
@@ -217,76 +199,43 @@ def t_stop(self):
217199

218200
def get_length(self):
219201
length = self.neorawio.get_signal_size(self.block_index, self.seg_index,
220-
stream_index=self.stream_index)
202+
**self.signal_indexing_kwarg)
221203
return length
222204

223205
def get_gains(self):
224-
return self.channels['gain']
206+
return self.channels['gain']
225207

226208
def get_offsets(self):
227-
return self.channels['offset']
209+
return self.channels['offset']
228210

229211
def get_shape(self):
230212
return (self.get_length(), self.nb_channel)
231213

232214
def get_chunk(self, i_start=None, i_stop=None):
233215
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
234-
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
235-
channel_indexes=self.channel_indexes)
216+
i_start=i_start, i_stop=i_stop, **self.get_chunk_kwargs)
236217
return sigs
237218

238219

239-
# handle old neo API <0.10
240-
class SpikeFromNeoRawIOSource_until_v9(BaseSpikeSource):
241-
def __init__(self, neorawio, channel_indexes=None):
242-
self.neorawio =neorawio
243-
if channel_indexes is None:
244-
channel_indexes = slice(None)
245-
self.channel_indexes = channel_indexes
246-
247-
self.channels = self.neorawio.header['unit_channels'][channel_indexes]
248-
249-
#TODO: something for multi segment
250-
self.block_index = 0
251-
self.seg_index = 0
252-
253-
@property
254-
def nb_channel(self):
255-
return len(self.channels)
256-
257-
def get_channel_name(self, chan=0):
258-
return self.channels[chan]['name']
259-
260-
@property
261-
def t_start(self):
262-
t_start = self.neorawio.segment_t_start(self.block_index, self.seg_index)
263-
return t_start
264-
265-
@property
266-
def t_stop(self):
267-
t_stop = self.neorawio.segment_t_stop(self.block_index, self.seg_index)
268-
return t_stop
269-
270-
def get_chunk(self, chan=0, i_start=None, i_stop=None):
271-
raise(NotImplementedError)
272-
273-
def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None):
274-
spike_timestamp = self.neorawio.get_spike_timestamps(block_index=self.block_index,
275-
seg_index=self.seg_index, unit_index=chan, t_start=t_start, t_stop=t_stop)
276-
277-
spike_times = self.neorawio.rescale_spike_timestamp(spike_timestamp, dtype='float64')
278220

279-
return spike_times
280221

281-
# this fit the new neo rawio API >=0.10
282222
class SpikeFromNeoRawIOSource(BaseSpikeSource):
283223
def __init__(self, neorawio, channel_indexes=None):
284224
self.neorawio =neorawio
285225
if channel_indexes is None:
286226
channel_indexes = slice(None)
287227
self.channel_indexes = channel_indexes
288228

289-
self.channels = self.neorawio.header['spike_channels'][channel_indexes]
229+
if V(neo.__version__)>='0.10.0':
230+
# Neo >= 0.10
231+
# - versions 0.10+ have spike_channels
232+
self.channels = self.neorawio.header['spike_channels'][channel_indexes]
233+
self.get_chunk_kwarg = 'spike_channel_index'
234+
else:
235+
# Neo < 0.10
236+
# - versions 0.6-0.9 have unit_channels
237+
self.channels = self.neorawio.header['unit_channels'][channel_indexes]
238+
self.get_chunk_kwarg = 'unit_index'
290239

291240
#TODO: something for multi segment
292241
self.block_index = 0
@@ -314,7 +263,7 @@ def get_chunk(self, chan=0, i_start=None, i_stop=None):
314263

315264
def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None):
316265
spike_timestamp = self.neorawio.get_spike_timestamps(block_index=self.block_index,
317-
seg_index=self.seg_index, spike_channel_index=chan, t_start=t_start, t_stop=t_stop)
266+
seg_index=self.seg_index, **{self.get_chunk_kwarg: chan}, t_start=t_start, t_stop=t_stop)
318267

319268
spike_times = self.neorawio.rescale_spike_timestamp(spike_timestamp, dtype='float64')
320269

@@ -397,39 +346,40 @@ def get_sources_from_neo_rawio(neorawio):
397346
sources = {'signal':[], 'epoch':[], 'spike':[]}
398347

399348

400-
# handle of neo version
401-
# this will be simplified in a while
402-
if hasattr(neorawio, 'get_group_signal_channel_indexes'):
403-
# Neo >= 0.9.0 and < 0.10
349+
if hasattr(neorawio, 'signal_streams_count'):
350+
# Neo >= 0.10.0
351+
# - version 0.10 replaced signal groups with signal streams
352+
for stream_index in range(neorawio.signal_streams_count()):
353+
# one source per signal stream
354+
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))
355+
elif hasattr(neorawio, 'get_group_signal_channel_indexes'):
356+
# Neo >= 0.9.0 and < 0.10
357+
# - version 0.9 renamed BaseRawIO.get_group_channel_indexes() to BaseRawIO.get_group_signal_channel_indexes()
404358
if neorawio.signal_channels_count() > 0:
405359
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
406360
for channel_indexes in channel_indexes_list:
407-
#one soure by channel group
408-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
361+
# one source per channel group
362+
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes))
409363
elif hasattr(neorawio, 'get_group_channel_indexes'):
410364
# Neo < 0.9.0
365+
# - versions 0.6-0.8 have BaseRawIO.get_group_channel_indexes()
411366
if neorawio.signal_channels_count() > 0:
412-
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
367+
channel_indexes_list = neorawio.get_group_channel_indexes()
413368
for channel_indexes in channel_indexes_list:
414-
#one soure by channel group
415-
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
416-
elif hasattr(neorawio, 'signal_streams_count'):
417-
# Neo >= 0.10.0
418-
num_streams = neorawio.signal_streams_count()
419-
for stream_index in range(num_streams):
420-
#one soure by stream
421-
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))
422-
369+
# one source per channel group
370+
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes))
423371

424372

425-
if hasattr(neorawio, 'unit_channels_count'):
426-
# Neo < 0.10
427-
if neorawio.unit_channels_count()>0:
428-
sources['spike'].append(SpikeFromNeoRawIOSource_until_v9(neorawio, None))
429-
elif hasattr(neorawio, 'spike_channels_count'):
430-
# neo >= 0.10
373+
if hasattr(neorawio, 'spike_channels_count'):
374+
# Neo >= 0.10
375+
# - version 0.10 renamed BaseRawIO.unit_channels_count() to BaseRawIO.spike_channels_count()
431376
if neorawio.spike_channels_count()>0:
432377
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))
378+
elif hasattr(neorawio, 'unit_channels_count'):
379+
# Neo < 0.10
380+
# - versions 0.6-0.9 have BaseRawIO.unit_channels_count()
381+
if neorawio.unit_channels_count()>0:
382+
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))
433383

434384

435385
if neorawio.event_channels_count()>0:

ephyviewer/tests/test_datasource.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ def test_spikeinterface_sources():
174174

175175

176176
if __name__=='__main__':
177-
#~ test_InMemoryAnalogSignalSource()
178-
#~ test_VideoMultiFileSource()
179-
#~ test_InMemoryEventSource()
180-
#~ test_InMemoryEpochSource()
181-
#~ test_spikesource()
177+
test_InMemoryAnalogSignalSource()
178+
test_VideoMultiFileSource()
179+
test_InMemoryEventSource()
180+
test_InMemoryEpochSource()
181+
test_spikesource()
182182
test_neo_rawio_sources()
183-
#~ test_neo_object_sources()
184-
#~ test_spikeinterface_sources()
183+
test_neo_object_sources()
184+
test_spikeinterface_sources()

0 commit comments

Comments
 (0)