Skip to content

Commit

Permalink
Fix/dither spike train with numpy>1.23.0 (#589)
Browse files Browse the repository at this point in the history
* add typehints
* update docstring
* write regression test for issue #586
  • Loading branch information
Moritz-Alexander-Kern authored Sep 14, 2023
1 parent e10f643 commit c3eb21f
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 138 deletions.
118 changes: 68 additions & 50 deletions elephant/spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import random
import warnings
import copy
from typing import Union, Optional, List

import neo
import numpy as np
Expand Down Expand Up @@ -66,14 +67,14 @@
'bin_shuffling', 'isi_dithering')


def _dither_spikes_with_refractory_period(spiketrain, dither, n_surrogates,
refractory_period):
def _dither_spikes_with_refractory_period(spiketrain: neo.SpikeTrain,
dither: float,
n_surrogates: int,
refractory_period: float
) -> np.array:
units = spiketrain.units
t_start = spiketrain.t_start.rescale(units).magnitude
t_stop = spiketrain.t_stop.rescale(units).magnitude

dither = dither.rescale(units).magnitude
refractory_period = refractory_period.rescale(units).magnitude
# The initially guesses refractory period is compared to the minimal ISI.
# The smaller value is taken as the refractory to calculate with.
refractory_period = np.min(np.diff(spiketrain.magnitude),
Expand Down Expand Up @@ -108,14 +109,45 @@ def _dither_spikes_with_refractory_period(spiketrain, dither, n_surrogates,

dithered_spiketrains.append(dithered_st)

dithered_spiketrains = np.array(dithered_spiketrains) * units
dithered_spiketrains = np.array(dithered_spiketrains)

return dithered_spiketrains


def _dither_spikes(spiketrain: neo.SpikeTrain, dither: float,
n_surrogates: int, edges: bool) -> np.array:
units = spiketrain.units
t_start = spiketrain.t_start.rescale(units).magnitude.item()
t_stop = spiketrain.t_stop.rescale(units).magnitude.item()
# Main: generate the surrogates
dithered_spiketrains = \
spiketrain.magnitude.reshape((1, len(spiketrain))) \
+ 2 * dither * np.random.random_sample(
(n_surrogates, len(spiketrain))) - dither
dithered_spiketrains.sort(axis=1)

if edges:
# Leave out all spikes outside [spiketrain.t_start, spiketrain.t_stop]
dithered_spiketrains = [
train[np.all([t_start < train, train < t_stop], axis=0)]
for train in dithered_spiketrains]
else:
# Move all spikes outside
# [spiketrain.t_start, spiketrain.t_stop] to the range's ends
dithered_spiketrains = np.minimum(
np.maximum(dithered_spiketrains, t_start), t_stop)

return dithered_spiketrains


@deprecated_alias(n='n_surrogates')
def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None,
edges=True, refractory_period=None):
def dither_spikes(spiketrain: neo.SpikeTrain,
dither: pq.Quantity,
n_surrogates: Optional[int] = 1,
decimals: Optional[int] = None,
edges: Optional[bool] = True,
refractory_period: Optional[Union[pq.Quantity, None]] = None
) -> List[neo.SpikeTrain]:
"""
Generates surrogates of a spike train by spike dithering.
Expand All @@ -129,7 +161,7 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None,
Parameters
----------
spiketrain : neo.SpikeTrain
spiketrain : :class:`neo.core.SpikeTrain`
The spike train from which to generate the surrogates.
dither : pq.Quantity
Amount of dithering. A spike at time `t` is placed randomly within
Expand Down Expand Up @@ -161,8 +193,8 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None,
Returns
-------
list of neo.SpikeTrain
Each surrogate spike train obtained independently from `spiketrain` by
list of :class:`neo.core.SpikeTrain`
Each surrogate spike train obtained independently of `spiketrain` by
randomly dithering its spikes. The range of the surrogate spike trains
is the same as of `spiketrain`.
Expand All @@ -186,54 +218,40 @@ def dither_spikes(spiketrain, dither, n_surrogates=1, decimals=None,
[0.0 ms, 1000.0 ms])>]
"""
# The trivial case
if len(spiketrain) == 0:
# return the empty spiketrain n times
# return the empty spiketrain n_surrogates times
return [spiketrain.copy() for _ in range(n_surrogates)]

# Handle units
units = spiketrain.units
t_start = spiketrain.t_start.rescale(units).magnitude
t_stop = spiketrain.t_stop.rescale(units).magnitude

if refractory_period is None or refractory_period == 0:
# Main: generate the surrogates
dither = dither.rescale(units).magnitude
dithered_spiketrains = \
spiketrain.magnitude.reshape((1, len(spiketrain))) \
+ 2 * dither * np.random.random_sample(
(n_surrogates, len(spiketrain))) - dither
dithered_spiketrains.sort(axis=1)

if edges:
# Leave out all spikes outside
# [spiketrain.t_start, spiketrain.t_stop]
dithered_spiketrains = \
[train[
np.all([t_start < train, train < t_stop], axis=0)]
for train in dithered_spiketrains]
else:
# Move all spikes outside
# [spiketrain.t_start, spiketrain.t_stop] to the range's ends
dithered_spiketrains = np.minimum(
np.maximum(dithered_spiketrains, t_start),
t_stop)

dithered_spiketrains = dithered_spiketrains * units
dither = dither.rescale(units).magnitude.item()

if not refractory_period:
dithered_spiketrains = _dither_spikes(
spiketrain, dither, n_surrogates, edges)
elif isinstance(refractory_period, pq.Quantity):
refractory_period = refractory_period.rescale(units).magnitude.item()

dithered_spiketrains = _dither_spikes_with_refractory_period(
spiketrain, dither, n_surrogates, refractory_period)
else:
raise ValueError("refractory_period must be of type pq.Quantity")

# Round the surrogate data to decimal position, if requested
if decimals is not None:
dithered_spiketrains = \
dithered_spiketrains.rescale(pq.ms).round(decimals).rescale(units)

# Return the surrogates as list of neo.SpikeTrain
return [neo.SpikeTrain(train, t_start=t_start, t_stop=t_stop,
sampling_rate=spiketrain.sampling_rate)
for train in dithered_spiketrains]
if decimals:
return [neo.SpikeTrain(
(train * units).rescale(pq.ms).round(decimals).rescale(units),
t_start=spiketrain.t_start, t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate)
for train in dithered_spiketrains]
else:
# Return the surrogates as list of neo.SpikeTrain
return [neo.SpikeTrain(
train * units,
t_start=spiketrain.t_start, t_stop=spiketrain.t_stop,
sampling_rate=spiketrain.sampling_rate)
for train in dithered_spiketrains]


@deprecated_alias(n='n_surrogates')
Expand Down Expand Up @@ -393,7 +411,7 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None,
Parameters
----------
spiketrain : neo.SpikeTrain
spiketrain : :class:`neo.core.SpikeTrain`
The spike train from which to generate the surrogates.
shift : pq.Quantity
Amount of shift. `spiketrain` is shifted by a random amount uniformly
Expand All @@ -413,8 +431,8 @@ def dither_spike_train(spiketrain, shift, n_surrogates=1, decimals=None,
Returns
-------
list of neo.SpikeTrain
Each surrogate spike train obtained independently from `spiketrain` by
list of :class:`neo.core.SpikeTrain`
Each surrogate spike train obtained independently of `spiketrain` by
randomly dithering the whole spike train. The time range of the
surrogate spike trains is the same as in `spiketrain`.
Expand Down
Loading

0 comments on commit c3eb21f

Please sign in to comment.