Skip to content

Commit

Permalink
[Feature] Trial handling in Elephant (#579)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Michael Denker <[email protected]>
  • Loading branch information
Moritz-Alexander-Kern and mdenker authored Oct 31, 2023
1 parent 8bac14c commit d3a35a2
Show file tree
Hide file tree
Showing 8 changed files with 1,046 additions and 111 deletions.
8 changes: 8 additions & 0 deletions doc/_templates/autosummary/trials_class.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{{ fullname | escape | underline }}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
:special-members: __contains__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__
:members:
:exclude-members: __getitem__,__init__
9 changes: 5 additions & 4 deletions doc/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Spike trains
:maxdepth: 2

reference/spike_train_generation


********************************
LFPs and spike trains (combined)
Expand Down Expand Up @@ -77,14 +77,15 @@ Waveforms

reference/waveform_features

********************************
Alternative data representations
********************************
********************
Data Representations
********************

.. toctree::
:maxdepth: 1

reference/conversion
reference/trials

*************
Miscellaneous
Expand Down
5 changes: 5 additions & 0 deletions doc/reference/trials.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
=====================
Trial representations
=====================

.. automodule:: elephant.trials
31 changes: 16 additions & 15 deletions elephant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,28 @@
:license: Modified BSD, see LICENSE.txt for details.
"""

from . import (statistics,
spike_train_generation,
spike_train_synchrony,
spike_train_correlation,
unitary_event_analysis,
from . import (cell_assembly_detection,
change_point_detection,
conversion,
cubic,
spectral,
current_source_density,
gpfa,
kernels,
neo_tools,
phase_analysis,
signal_processing,
spade,
spectral,
spike_train_correlation,
spike_train_dissimilarity,
spike_train_generation,
spike_train_surrogates,
signal_processing,
current_source_density,
change_point_detection,
phase_analysis,
spike_train_synchrony,
sta,
conversion,
neo_tools,
cell_assembly_detection,
spade,
trials,
unitary_event_analysis,
waveform_features,
gpfa)
statistics)

# not included modules on purpose:
# parallel: avoid warns when elephant is imported
Expand Down
181 changes: 152 additions & 29 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,17 @@
import warnings

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
import scipy.stats
import scipy.signal
from numpy import ndarray
from scipy.special import erf
from typing import Union

import elephant.conversion as conv
import elephant.kernels as kernels
import elephant.trials
from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias, check_neo_consistency, \
is_time_quantity, round_binning_errors
Expand Down Expand Up @@ -601,7 +603,8 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False):
@deprecated_alias(spiketrain='spiketrains')
def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
cutoff=5.0, t_start=None, t_stop=None, trim=False,
center_kernel=True, border_correction=False):
center_kernel=True, border_correction=False,
pool_trials=False, pool_spike_trains=False):
r"""
Estimates instantaneous firing rate by kernel convolution.
Expand All @@ -610,9 +613,12 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
Parameters
----------
spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
Neo object(s) that contains spike times, the unit of the time stamps,
and `t_start` and `t_stop` of the spike train.
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
Input spike train(s) for which the instantaneous firing rate is
calculated. If a list of spike trains is supplied, the parameter
pool_spike_trains determines the behavior of the function. If a Trials
object is supplied, the behavior is determined by the parameters
pool_spike_trains (within a trial) and pool_trials (across trials).
sampling_period : pq.Quantity
Time stamp resolution of the spike times. The same resolution will
be assumed for the kernel.
Expand Down Expand Up @@ -680,6 +686,21 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
these spike train borders under the assumption that the rate does not
change strongly.
Only possible in the case of a Gaussian kernel.
Default: False
pool_trials: bool, optional
If true, calculate firing rates averaged over trials if spiketrains is
of type elephant.trials.Trials
Has no effect for single spike train or lists of spike trains.
Default: False
pool_spike_trains: bool, optional
If true, calculate firing rates averaged over spike trains. If the
input is a Trials object, spike trains are pooled across spike trains
within each trial, and pool_trials determines whether spike trains are
additionally pooled across trials.
Has no effect for a single spike train.
Default: False
Returns
Expand Down Expand Up @@ -788,6 +809,86 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',
[0.05842767]])
"""
if isinstance(spiketrains, elephant.trials.Trials):
kwargs = {
'kernel': kernel,
'cutoff': cutoff,
't_start': t_start,
't_stop': t_stop,
'trim': trim,
'center_kernel': center_kernel,
'border_correction': border_correction,
'pool_trials': False,
'pool_spike_trains': False,
}

if pool_trials:
list_of_lists_of_spiketrains = [
spiketrains.get_spiketrains_from_trial_as_list(
trial_id=trial_no)
for trial_no in range(spiketrains.n_trials)]

spiketrains_cross_trials = (
[list_of_lists_of_spiketrains[trial_no][spiketrain_idx]
for trial_no in range(spiketrains.n_trials)]
for spiketrain_idx, spiketrain in
enumerate(list_of_lists_of_spiketrains[0]))

rates_cross_trials = [instantaneous_rate(spiketrain,
sampling_period,
**kwargs)
for spiketrain in spiketrains_cross_trials]

average_rate_cross_trials = (
np.mean(rates, axis=1) for rates in rates_cross_trials)
if pool_spike_trains:
average_rate = np.mean(list(average_rate_cross_trials), axis=0)
analog_signal = rates_cross_trials[0]

return (neo.AnalogSignal(
signal=average_rate,
sampling_period=analog_signal.sampling_period,
units=analog_signal.units,
t_start=analog_signal.t_start,
t_stop=analog_signal.t_stop,
kernel=analog_signal.annotations)
)

list_of_average_rates_cross_trial = neo.AnalogSignal(
signal=np.array(list(average_rate_cross_trials)).transpose(),
sampling_period=rates_cross_trials[0].sampling_period,
units=rates_cross_trials[0].units,
t_start=rates_cross_trials[0].t_start,
t_stop=rates_cross_trials[0].t_stop,
kernel=rates_cross_trials[0].annotations)

return list_of_average_rates_cross_trial

if not pool_trials and not pool_spike_trains:
return [instantaneous_rate(
spiketrains.get_spiketrains_from_trial_as_list(
trial_id=trial_no), sampling_period, **kwargs)
for trial_no in range(spiketrains.n_trials)]

if not pool_trials and pool_spike_trains:
rates = [instantaneous_rate(
spiketrains.get_spiketrains_from_trial_as_list(
trial_id=trial_no), sampling_period, **kwargs)
for trial_no in range(spiketrains.n_trials)]

average_rates = (np.mean(rate, axis=1) for rate in rates)

list_of_average_rates_over_spiketrains = [
neo.AnalogSignal(signal=average_rate,
sampling_period=analog_signal.sampling_period,
units=analog_signal.units,
t_start=analog_signal.t_start,
t_stop=analog_signal.t_stop,
kernel=analog_signal.annotations)
for average_rate, analog_signal in zip(average_rates, rates)]

return list_of_average_rates_over_spiketrains

def optimal_kernel(st):
width_sigma = None
if len(st) > 0:
Expand Down Expand Up @@ -930,6 +1031,10 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
pool_spike_trains):
rate = np.mean(rate, axis=1)

rate = neo.AnalogSignal(signal=rate,
sampling_period=sampling_period,
units=pq.Hz, t_start=t_start, t_stop=t_stop,
Expand Down Expand Up @@ -1035,7 +1140,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
... neo.SpikeTrain([0.3, 4.5, 6.7, 9.3], t_stop=10, units='s'),
... neo.SpikeTrain([0.7, 4.3, 8.2], t_stop=10, units='s')
... ]
>>> hist = statistics.time_histogram(spiketrains, bin_size=1 * pq.s)
>>> hist = statistics.time_histogram(spiketrains,
... bin_size=1 * pq.s)
>>> hist
<AnalogSignal(array([[2],
[0],
Expand All @@ -1053,32 +1159,49 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
"""
# Bin the spike trains and sum across columns
bs = BinnedSpikeTrain(spiketrains, t_start=t_start, t_stop=t_stop,
bin_size=bin_size)

if binary:
bs = bs.binarize(copy=False)
bin_hist = bs.get_num_of_spikes(axis=0)
# Flatten array
bin_hist = np.ravel(bin_hist)
# Renormalise the histogram
if output == 'counts':
# Raw
bin_hist = pq.Quantity(bin_hist, units=pq.dimensionless, copy=False)
elif output == 'mean':
# Divide by number of input spike trains
bin_hist = pq.Quantity(bin_hist / len(spiketrains),
units=pq.dimensionless, copy=False)
elif output == 'rate':
# Divide by number of input spike trains and bin width
bin_hist = bin_hist / (len(spiketrains) * bin_size)
binned_spiketrain = BinnedSpikeTrain(spiketrains,
t_start=t_start,
t_stop=t_stop, bin_size=bin_size
).binarize(copy=False)
else:
binned_spiketrain = BinnedSpikeTrain(spiketrains,
t_start=t_start,
t_stop=t_stop, bin_size=bin_size
)

bin_hist: Union[int, ndarray] = binned_spiketrain.get_num_of_spikes(axis=0)
# Flatten array
bin_hist.ravel()

# Re-normalise the histogram according to desired output

def _counts() -> pq.Quantity:
# 'counts': spike counts at each bin (as integer numbers).
return pq.Quantity(bin_hist, units=pq.dimensionless, copy=False)

def _mean() -> pq.Quantity:
# 'mean': mean spike counts per spike train.
return pq.Quantity(bin_hist / len(spiketrains),
units=pq.dimensionless, copy=False)

def _rate() -> pq.Quantity:
# 'rate': mean spike rate per spike train. Like 'mean', but the
# counts are additionally normalized by the bin width.
return bin_hist / (len(spiketrains) * bin_size)

output_mapping = {"counts": _counts, "mean": _mean, "rate": _rate}
try:
normalise_func = output_mapping.get(output)
normalised_bin_hist = normalise_func()
except TypeError:
raise ValueError(f'Parameter output ({output}) is not valid.')

return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1),
sampling_period=bin_size, units=bin_hist.units,
t_start=bs.t_start, normalization=output,
copy=False)
return neo.AnalogSignal(signal=np.expand_dims(normalised_bin_hist, axis=1),
sampling_period=bin_size,
units=normalised_bin_hist.units,
t_start=binned_spiketrain.t_start,
normalization=output, copy=False)


@deprecated_alias(binsize='bin_size')
Expand Down Expand Up @@ -1343,7 +1466,7 @@ def pdf(self):
`t_start + j * binsize` and `t_start + (j + 1) * binsize`.
"""
norm_hist = self.complexity_histogram / self.complexity_histogram.sum()
# Convert the Complexity pdf to an neo.AnalogSignal
# Convert the Complexity pdf to a neo.AnalogSignal
pdf = neo.AnalogSignal(
np.expand_dims(norm_hist, axis=1),
units=pq.dimensionless,
Expand Down
Loading

0 comments on commit d3a35a2

Please sign in to comment.