From d3a35a2951d4d8559fbed7a99a12e5359b5d3d42 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:37:49 +0100 Subject: [PATCH] [Feature] Trial handling in Elephant (#579) --------- Co-authored-by: Michael Denker --- doc/_templates/autosummary/trials_class.rst | 8 + doc/modules.rst | 9 +- doc/reference/trials.rst | 5 + elephant/__init__.py | 31 +- elephant/statistics.py | 181 +++++++-- elephant/test/test_statistics.py | 163 +++++--- elephant/test/test_trials.py | 333 +++++++++++++++ elephant/trials.py | 427 ++++++++++++++++++++ 8 files changed, 1046 insertions(+), 111 deletions(-) create mode 100644 doc/_templates/autosummary/trials_class.rst create mode 100644 doc/reference/trials.rst create mode 100644 elephant/test/test_trials.py create mode 100644 elephant/trials.py diff --git a/doc/_templates/autosummary/trials_class.rst b/doc/_templates/autosummary/trials_class.rst new file mode 100644 index 000000000..5742424c4 --- /dev/null +++ b/doc/_templates/autosummary/trials_class.rst @@ -0,0 +1,8 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :special-members: __contains__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__ + :members: + :exclude-members: __getitem__,__init__ diff --git a/doc/modules.rst b/doc/modules.rst index 0c07f6bbf..96cc8bc18 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -46,7 +46,7 @@ Spike trains :maxdepth: 2 reference/spike_train_generation - + ******************************** LFPs and spike trains (combined) @@ -77,14 +77,15 @@ Waveforms reference/waveform_features -******************************** -Alternative data representations -******************************** +******************** +Data Representations +******************** .. toctree:: :maxdepth: 1 reference/conversion + reference/trials ************* Miscellaneous diff --git a/doc/reference/trials.rst b/doc/reference/trials.rst new file mode 100644 index 000000000..bf1815596 --- /dev/null +++ b/doc/reference/trials.rst @@ -0,0 +1,5 @@ +===================== +Trial representations +===================== + +.. automodule:: elephant.trials diff --git a/elephant/__init__.py b/elephant/__init__.py index 6c29234d6..f542a3ff0 100644 --- a/elephant/__init__.py +++ b/elephant/__init__.py @@ -6,27 +6,28 @@ :license: Modified BSD, see LICENSE.txt for details. """ -from . import (statistics, - spike_train_generation, - spike_train_synchrony, - spike_train_correlation, - unitary_event_analysis, +from . import (cell_assembly_detection, + change_point_detection, + conversion, cubic, - spectral, + current_source_density, + gpfa, kernels, + neo_tools, + phase_analysis, + signal_processing, + spade, + spectral, + spike_train_correlation, spike_train_dissimilarity, + spike_train_generation, spike_train_surrogates, - signal_processing, - current_source_density, - change_point_detection, - phase_analysis, + spike_train_synchrony, sta, - conversion, - neo_tools, - cell_assembly_detection, - spade, + trials, + unitary_event_analysis, waveform_features, - gpfa) + statistics) # not included modules on purpose: # parallel: avoid warns when elephant is imported diff --git a/elephant/statistics.py b/elephant/statistics.py index 139c73744..f98922b32 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -68,15 +68,17 @@ import warnings import neo -from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq import scipy.stats import scipy.signal +from numpy import ndarray from scipy.special import erf +from typing import Union import elephant.conversion as conv import elephant.kernels as kernels +import elephant.trials from elephant.conversion import BinnedSpikeTrain from elephant.utils import deprecated_alias, check_neo_consistency, \ is_time_quantity, round_binning_errors @@ -601,7 +603,8 @@ def lvr(time_intervals, R=5*pq.ms, with_nan=False): @deprecated_alias(spiketrain='spiketrains') def instantaneous_rate(spiketrains, sampling_period, kernel='auto', cutoff=5.0, t_start=None, t_stop=None, trim=False, - center_kernel=True, border_correction=False): + center_kernel=True, border_correction=False, + pool_trials=False, pool_spike_trains=False): r""" Estimates instantaneous firing rate by kernel convolution. @@ -610,9 +613,12 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', Parameters ---------- - spiketrains : neo.SpikeTrain or list of neo.SpikeTrain - Neo object(s) that contains spike times, the unit of the time stamps, - and `t_start` and `t_stop` of the spike train. + spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa + Input spike train(s) for which the instantaneous firing rate is + calculated. If a list of spike trains is supplied, the parameter + pool_spike_trains determines the behavior of the function. If a Trials + object is supplied, the behavior is determined by the parameters + pool_spike_trains (within a trial) and pool_trials (across trials). sampling_period : pq.Quantity Time stamp resolution of the spike times. The same resolution will be assumed for the kernel. @@ -680,6 +686,21 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', these spike train borders under the assumption that the rate does not change strongly. Only possible in the case of a Gaussian kernel. + + Default: False + pool_trials: bool, optional + If true, calculate firing rates averaged over trials if spiketrains is + of type elephant.trials.Trials + Has no effect for single spike train or lists of spike trains. + + Default: False + pool_spike_trains: bool, optional + If true, calculate firing rates averaged over spike trains. If the + input is a Trials object, spike trains are pooled across spike trains + within each trial, and pool_trials determines whether spike trains are + additionally pooled across trials. + Has no effect for a single spike train. + Default: False Returns @@ -788,6 +809,86 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', [0.05842767]]) """ + if isinstance(spiketrains, elephant.trials.Trials): + kwargs = { + 'kernel': kernel, + 'cutoff': cutoff, + 't_start': t_start, + 't_stop': t_stop, + 'trim': trim, + 'center_kernel': center_kernel, + 'border_correction': border_correction, + 'pool_trials': False, + 'pool_spike_trains': False, + } + + if pool_trials: + list_of_lists_of_spiketrains = [ + spiketrains.get_spiketrains_from_trial_as_list( + trial_id=trial_no) + for trial_no in range(spiketrains.n_trials)] + + spiketrains_cross_trials = ( + [list_of_lists_of_spiketrains[trial_no][spiketrain_idx] + for trial_no in range(spiketrains.n_trials)] + for spiketrain_idx, spiketrain in + enumerate(list_of_lists_of_spiketrains[0])) + + rates_cross_trials = [instantaneous_rate(spiketrain, + sampling_period, + **kwargs) + for spiketrain in spiketrains_cross_trials] + + average_rate_cross_trials = ( + np.mean(rates, axis=1) for rates in rates_cross_trials) + if pool_spike_trains: + average_rate = np.mean(list(average_rate_cross_trials), axis=0) + analog_signal = rates_cross_trials[0] + + return (neo.AnalogSignal( + signal=average_rate, + sampling_period=analog_signal.sampling_period, + units=analog_signal.units, + t_start=analog_signal.t_start, + t_stop=analog_signal.t_stop, + kernel=analog_signal.annotations) + ) + + list_of_average_rates_cross_trial = neo.AnalogSignal( + signal=np.array(list(average_rate_cross_trials)).transpose(), + sampling_period=rates_cross_trials[0].sampling_period, + units=rates_cross_trials[0].units, + t_start=rates_cross_trials[0].t_start, + t_stop=rates_cross_trials[0].t_stop, + kernel=rates_cross_trials[0].annotations) + + return list_of_average_rates_cross_trial + + if not pool_trials and not pool_spike_trains: + return [instantaneous_rate( + spiketrains.get_spiketrains_from_trial_as_list( + trial_id=trial_no), sampling_period, **kwargs) + for trial_no in range(spiketrains.n_trials)] + + if not pool_trials and pool_spike_trains: + rates = [instantaneous_rate( + spiketrains.get_spiketrains_from_trial_as_list( + trial_id=trial_no), sampling_period, **kwargs) + for trial_no in range(spiketrains.n_trials)] + + average_rates = (np.mean(rate, axis=1) for rate in rates) + + list_of_average_rates_over_spiketrains = [ + neo.AnalogSignal(signal=average_rate, + sampling_period=analog_signal.sampling_period, + units=analog_signal.units, + t_start=analog_signal.t_start, + t_stop=analog_signal.t_stop, + kernel=analog_signal.annotations) + for average_rate, analog_signal in zip(average_rates, rates)] + + return list_of_average_rates_over_spiketrains + def optimal_kernel(st): width_sigma = None if len(st) > 0: @@ -930,6 +1031,10 @@ def optimal_kernel(st): sigma=str(kernel.sigma), invert=kernel.invert) + if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and ( + pool_spike_trains): + rate = np.mean(rate, axis=1) + rate = neo.AnalogSignal(signal=rate, sampling_period=sampling_period, units=pq.Hz, t_start=t_start, t_stop=t_stop, @@ -1035,7 +1140,8 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None, ... neo.SpikeTrain([0.3, 4.5, 6.7, 9.3], t_stop=10, units='s'), ... neo.SpikeTrain([0.7, 4.3, 8.2], t_stop=10, units='s') ... ] - >>> hist = statistics.time_histogram(spiketrains, bin_size=1 * pq.s) + >>> hist = statistics.time_histogram(spiketrains, + ... bin_size=1 * pq.s) >>> hist pq.Quantity: + # 'counts': spike counts at each bin (as integer numbers). + return pq.Quantity(bin_hist, units=pq.dimensionless, copy=False) + + def _mean() -> pq.Quantity: + # 'mean': mean spike counts per spike train. + return pq.Quantity(bin_hist / len(spiketrains), + units=pq.dimensionless, copy=False) + + def _rate() -> pq.Quantity: + # 'rate': mean spike rate per spike train. Like 'mean', but the + # counts are additionally normalized by the bin width. + return bin_hist / (len(spiketrains) * bin_size) + + output_mapping = {"counts": _counts, "mean": _mean, "rate": _rate} + try: + normalise_func = output_mapping.get(output) + normalised_bin_hist = normalise_func() + except TypeError: raise ValueError(f'Parameter output ({output}) is not valid.') - return neo.AnalogSignal(signal=np.expand_dims(bin_hist, axis=1), - sampling_period=bin_size, units=bin_hist.units, - t_start=bs.t_start, normalization=output, - copy=False) + return neo.AnalogSignal(signal=np.expand_dims(normalised_bin_hist, axis=1), + sampling_period=bin_size, + units=normalised_bin_hist.units, + t_start=binned_spiketrain.t_start, + normalization=output, copy=False) @deprecated_alias(binsize='bin_size') @@ -1343,7 +1466,7 @@ def pdf(self): `t_start + j * binsize` and `t_start + (j + 1) * binsize`. """ norm_hist = self.complexity_histogram / self.complexity_histogram.sum() - # Convert the Complexity pdf to an neo.AnalogSignal + # Convert the Complexity pdf to a neo.AnalogSignal pdf = neo.AnalogSignal( np.expand_dims(norm_hist, axis=1), units=pq.dimensionless, diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 4030f37db..aaa48412d 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -20,6 +20,8 @@ import elephant.kernels as kernels from elephant import statistics from elephant.spike_train_generation import StationaryPoissonProcess +from elephant.test.test_trials import _create_trials_block +from elephant.trials import TrialsFromBlock class IsiTestCase(unittest.TestCase): @@ -482,32 +484,40 @@ def test_cv2_raise_error(self): class InstantaneousRateTest(unittest.TestCase): - def setUp(self): - # create a poisson spike train: - self.st_tr = (0, 20.0) # seconds - self.st_dur = self.st_tr[1] - self.st_tr[0] # seconds - self.st_margin = 5.0 # seconds - self.st_rate = 10.0 # Hertz + @classmethod + def setUpClass(cls) -> None: + """ + Run once before tests: + """ + + block = _create_trials_block(n_trials=36) + cls.block = block + cls.trial_object = TrialsFromBlock(block, + description='trials are segments') + # create a poisson spike train: + cls.st_tr = (0, 20.0) # seconds + cls.st_dur = cls.st_tr[1] - cls.st_tr[0] # seconds + cls.st_margin = 5.0 # seconds + cls.st_rate = 10.0 # Hertz np.random.seed(19) - duration_effective = self.st_dur - 2 * self.st_margin - st_num_spikes = np.random.poisson(self.st_rate * duration_effective) + duration_effective = cls.st_dur - 2 * cls.st_margin + st_num_spikes = np.random.poisson( + cls.st_rate * duration_effective) spike_train = sorted( np.random.rand(st_num_spikes) * duration_effective + - self.st_margin) - + cls.st_margin) # convert spike train into neo objects - self.spike_train = neo.SpikeTrain(spike_train * pq.s, - t_start=self.st_tr[0] * pq.s, - t_stop=self.st_tr[1] * pq.s) - + cls.spike_train = neo.SpikeTrain(spike_train * pq.s, + t_start=cls.st_tr[0] * pq.s, + t_stop=cls.st_tr[1] * pq.s) # generation of a multiply used specific kernel - self.kernel = kernels.TriangularKernel(sigma=0.03 * pq.s) + cls.kernel = kernels.TriangularKernel(sigma=0.03 * pq.s) # calculate instantaneous rate - self.sampling_period = 0.01 * pq.s - self.inst_rate = statistics.instantaneous_rate( - self.spike_train, self.sampling_period, self.kernel, cutoff=0) + cls.sampling_period = 0.01 * pq.s + cls.inst_rate = statistics.instantaneous_rate( + cls.spike_train, cls.sampling_period, cls.kernel, cutoff=0) def test_instantaneous_rate_warnings(self): with self.assertWarns(UserWarning): @@ -646,7 +656,7 @@ def test_instantaneous_rate_regression_288(self): 10 * pq.Hz, t_start=0 * pq.s, t_stop=10 * pq.s).generate_spiketrain() kernel = kernels.AlphaKernel(sigma=5 * pq.ms, invert=True) - rate = statistics.instantaneous_rate( + _ = statistics.instantaneous_rate( spiketrain, sampling_period=sampling_period, kernel=kernel) except ValueError: self.fail('When providing a kernel on a much smaller time scale ' @@ -930,50 +940,77 @@ def test_instantaneous_rate_bin_edges(self): self.assertAlmostEqual(spike_times[3].magnitude.item(), rate.times[rate.argmax()].magnitude.item()) - def test_instantaneous_rate_border_correction(self): - np.random.seed(0) - n_spiketrains = 125 - rate = 50. * pq.Hz - t_start = 0. * pq.ms - t_stop = 1000. * pq.ms - - sampling_period = 0.1 * pq.ms - - trial_list = StationaryPoissonProcess( - rate=rate, t_start=t_start, t_stop=t_stop - ).generate_n_spiketrains(n_spiketrains) - - for correction in (True, False): - rates = [] - for trial in trial_list: - # calculate the instantaneous rate, discard extra dimension - instantaneous_rate = statistics.instantaneous_rate( - spiketrains=trial, - sampling_period=sampling_period, - kernel='auto', - border_correction=correction - ) - rates.append(instantaneous_rate) - - # The average estimated rate gives the average estimated value of - # the firing rate in each time bin. - # Note: the indexing [:, 0] is necessary to get the output an - # one-dimensional array. - average_estimated_rate = np.mean(rates, axis=0)[:, 0] - - rtol = 0.05 # Five percent of tolerance - - if correction: - self.assertLess(np.max(average_estimated_rate), - (1. + rtol) * rate.item()) - self.assertGreater(np.min(average_estimated_rate), - (1. - rtol) * rate.item()) - else: - self.assertLess(np.max(average_estimated_rate), - (1. + rtol) * rate.item()) - # The minimal rate deviates strongly in the uncorrected case. - self.assertLess(np.min(average_estimated_rate), - (1. - rtol) * rate.item()) + def test_instantaneous_rate_border_correction(self): + np.random.seed(0) + n_spiketrains = 125 + rate = 50. * pq.Hz + t_start = 0. * pq.ms + t_stop = 1000. * pq.ms + sampling_period = 0.1 * pq.ms + trial_list = StationaryPoissonProcess( + rate=rate, t_start=t_start, t_stop=t_stop + ).generate_n_spiketrains(n_spiketrains) + for correction in (True, False): + rates = [] + for trial in trial_list: + # calculate the instantaneous rate, discard extra dimension + instantaneous_rate = statistics.instantaneous_rate( + spiketrains=trial, + sampling_period=sampling_period, + kernel='auto', + border_correction=correction + ) + rates.append(instantaneous_rate) + # The average estimated rate gives the average estimated value of + # the firing rate in each time bin. + # Note: the indexing [:, 0] is necessary to get the output an + # one-dimensional array. + average_estimated_rate = np.mean(rates, axis=0)[:, 0] + rtol = 0.05 # Five percent of tolerance + if correction: + self.assertLess(np.max(average_estimated_rate), + (1. + rtol) * rate.item()) + self.assertGreater(np.min(average_estimated_rate), + (1. - rtol) * rate.item()) + else: + self.assertLess(np.max(average_estimated_rate), + (1. + rtol) * rate.item()) + # The minimal rate deviates strongly in the uncorrected case. + self.assertLess(np.min(average_estimated_rate), + (1. - rtol) * rate.item()) + + def test_instantaneous_rate_trials_pool_trials(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=False, + pool_trials=True) + self.assertIsInstance(rate, neo.core.AnalogSignal) + + def test_instantaneous_rate_list_pool_spike_trains(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate( + self.trial_object.get_spiketrains_from_trial_as_list(0), + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=False) + self.assertIsInstance(rate, neo.core.AnalogSignal) + self.assertEqual(rate.magnitude.shape[1], 1) + + def test_instantaneous_rate_list_of_spike_trains(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + rate = statistics.instantaneous_rate( + self.trial_object.get_spiketrains_from_trial_as_list(0), + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=False, + pool_trials=False) + self.assertIsInstance(rate, neo.core.AnalogSignal) + self.assertEqual(rate.magnitude.shape[1], 2) class TimeHistogramTestCase(unittest.TestCase): diff --git a/elephant/test/test_trials.py b/elephant/test/test_trials.py new file mode 100644 index 000000000..11138fef3 --- /dev/null +++ b/elephant/test/test_trials.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for the trials objects. + +:copyright: Copyright 2014-2023 by the Elephant team, see AUTHORS.txt. +:license: Modified BSD, see LICENSE.txt for details. +""" + +import unittest +import neo.utils +import quantities as pq +from neo.core import AnalogSignal + +from elephant.spike_train_generation import StationaryPoissonProcess +from elephant.trials import TrialsFromBlock, TrialsFromLists + + +def _create_trials_block(n_trials: int = 0, + n_spiketrains: int = 2, + n_analogsignals: int = 2) -> neo.core.Block: + """ Create block with n_trials, n_spiketrains and n_analog_signals """ + block = neo.Block(name='test_block') + for trial in range(n_trials): + segment = neo.Segment(name=f'No. {trial}') + spiketrains = StationaryPoissonProcess(rate=50. * pq.Hz, + t_start=0 * pq.ms, + t_stop=1000 * pq.ms + ).generate_n_spiketrains( + n_spiketrains=n_spiketrains) + analogsignals = [AnalogSignal(signal=[.01, 3.3, 9.3], units='uV', + sampling_rate=1 * pq.Hz) + ] * n_analogsignals + for spiketrain in spiketrains: + segment.spiketrains.append(spiketrain) + for analogsignal in analogsignals: + segment.analogsignals.append(analogsignal) + block.segments.append(segment) + return block + + +######### +# Tests # +######### + + +class TrialsFromBlockTestCase(unittest.TestCase): + """Tests for elephant.trials.TrialsFromBlock class""" + + @classmethod + def setUpClass(cls) -> None: + """ + Run once before tests: + """ + + block = _create_trials_block(n_trials=36) + cls.block = block + cls.trial_object = TrialsFromBlock(block, + description='trials are segments') + + def setUp(self) -> None: + """ + Run before every test: + """ + + def test_trials_from_block_description(self) -> None: + """ + Test description of the trials object. + """ + self.assertEqual(self.trial_object.description, 'trials are segments') + + def test_trials_from_block_get_item(self) -> None: + """ + Test get a trial from the trials. + """ + self.assertIsInstance(self.trial_object[0], neo.core.Segment) + + def test_trials_from_block_get_trial_as_segment(self) -> None: + """ + Test get a trial from the trials. + """ + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0).spiketrains[0], + neo.core.SpikeTrain) + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0).analogsignals[0], + neo.core.AnalogSignal) + + def test_trials_from_block_get_trials_as_block(self) -> None: + """ + Test get a block from list of trials. + """ + block = self.trial_object.get_trials_as_block([0, 3, 5]) + self.assertIsInstance(block, neo.core.Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), + neo.core.Block) + self.assertEqual(len(block.segments), 3) + + def test_trials_from_block_get_trials_as_list(self) -> None: + """ + Test get a list of segments from list of trials. + """ + list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) + self.assertIsInstance(list_of_trials, list) + self.assertIsInstance(self.trial_object.get_trials_as_list(), list) + self.assertIsInstance(list_of_trials[0], neo.core.Segment) + self.assertEqual(len(list_of_trials), 3) + + def test_trials_from_block_n_trials(self) -> None: + """ + Test get number of trials. + """ + self.assertEqual(self.trial_object.n_trials, len(self.block.segments)) + + def test_trials_from_block_n_spiketrains_trial_by_trial(self) -> None: + """ + Test get number of spiketrains per trial. + """ + self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, + [len(trial.spiketrains) for trial in + self.block.segments]) + + def test_trials_from_block_n_analogsignals_trial_by_trial(self) -> None: + """ + Test get number of analogsignals per trial. + """ + self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, + [len(trial.analogsignals) for trial in + self.block.segments]) + + def test_trials_from_block_get_spiketrains_from_trial_as_list(self + ) -> None: + """ + Test get spiketrains from trial as list + """ + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_list(0), + neo.core.spiketrainlist.SpikeTrainList) + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_list(0)[0], + neo.core.SpikeTrain) + + def test_trials_from_list_get_spiketrains_from_trial_as_segment(self + ) -> None: + """ + Test get spiketrains from trial as segment + """ + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_segment( + 0).spiketrains[0], neo.core.SpikeTrain) + + def test_trials_from_block_get_analogsignals_from_trial_as_list(self + ) -> None: + """ + Test get analogsignals from trial as list + """ + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_list(0), list) + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_list(0)[0], + neo.core.AnalogSignal) + + def test_trials_from_list_get_analogsignals_from_trial_as_segment(self) \ + -> None: + """ + Test get spiketrains from trial as segment + """ + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_segment( + 0).analogsignals[0], neo.core.AnalogSignal) + + +class TrialsFromListTestCase(unittest.TestCase): + """Tests for elephant.trials.TrialsFromList class""" + + @classmethod + def setUpClass(cls) -> None: + """ + Run once before tests: + Download the dataset from elephant_data + """ + block = _create_trials_block(n_trials=36) + + # Create Trialobject as list of lists + # add spiketrains + trial_list = [[spiketrain for spiketrain in trial.spiketrains] + for trial in block.segments] + # add analogsignals + for idx, trial in enumerate(block.segments): + for analogsignal in trial.analogsignals: + trial_list[idx].append(analogsignal) + cls.trial_list = trial_list + + cls.trial_object = TrialsFromLists(trial_list, + description='trial is a list') + + def setUp(self) -> None: + """ + Run before every test: + """ + + def test_trials_from_list_description(self) -> None: + """ + Test description of the trials object. + """ + self.assertEqual(self.trial_object.description, 'trial is a list') + + def test_trials_from_list_get_item(self) -> None: + """ + Test get a trial from the trials. + """ + self.assertIsInstance(self.trial_object[0], + neo.core.Segment) + self.assertIsInstance(self.trial_object[0].spiketrains[0], + neo.core.SpikeTrain) + self.assertIsInstance(self.trial_object[0].analogsignals[0], + neo.core.AnalogSignal) + + def test_trials_from_list_get_trial_as_segment(self) -> None: + """ + Test get a trial from the trials. + """ + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0).spiketrains[0], + neo.core.SpikeTrain) + self.assertIsInstance( + self.trial_object.get_trial_as_segment(0).analogsignals[0], + neo.core.AnalogSignal) + + def test_trials_from_list_get_trials_as_block(self) -> None: + """ + Test get a block from list of trials. + """ + block = self.trial_object.get_trials_as_block([0, 3, 5]) + self.assertIsInstance(block, neo.core.Block) + self.assertIsInstance(self.trial_object.get_trials_as_block(), + neo.core.Block) + self.assertEqual(len(block.segments), 3) + + def test_trials_from_list_get_trials_as_list(self) -> None: + """ + Test get a list of segments from list of trials. + """ + list_of_trials = self.trial_object.get_trials_as_list([0, 3, 5]) + self.assertIsInstance(list_of_trials, list) + self.assertIsInstance(self.trial_object.get_trials_as_list(), list) + self.assertIsInstance(list_of_trials[0], neo.core.Segment) + self.assertEqual(len(list_of_trials), 3) + + def test_trials_from_list_n_trials(self) -> None: + """ + Test get number of trials. + """ + self.assertEqual(self.trial_object.n_trials, len(self.trial_list)) + + def test_trials_from_list_n_spiketrains_trial_by_trial(self) -> None: + """ + Test get number of spiketrains per trial. + """ + self.assertEqual(self.trial_object.n_spiketrains_trial_by_trial, + [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), + trial)) for trial in self.trial_list]) + + def test_trials_from_list_n_analogsignals_trial_by_trial(self) -> None: + """ + Test get number of analogsignals per trial. + """ + self.assertEqual(self.trial_object.n_analogsignals_trial_by_trial, + [sum(map(lambda x: isinstance(x, + neo.core.AnalogSignal), + trial)) for trial in self.trial_list]) + + def test_trials_from_list_get_spiketrains_from_trial_as_list(self) -> None: + """ + Test get spiketrains from trial as list + """ + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_list(0), + neo.core.spiketrainlist.SpikeTrainList) + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_list(0)[0], + neo.core.SpikeTrain) + + def test_trials_from_list_get_spiketrains_from_trial_as_segment(self + ) -> None: + """ + Test get spiketrains from trial as segment + """ + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_spiketrains_from_trial_as_segment( + 0).spiketrains[0], neo.core.SpikeTrain) + + def test_trials_from_list_get_analogsignals_from_trial_as_list(self + ) -> None: + """ + Test get analogsignals from trial as list + """ + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_list(0), list) + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_list(0)[0], + neo.core.AnalogSignal) + + def test_trials_from_list_get_analogsignals_from_trial_as_segment(self + ) \ + -> None: + """ + Test get spiketrains from trial as segment + """ + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_segment(0), + neo.core.Segment) + self.assertIsInstance( + self.trial_object.get_analogsignals_from_trial_as_segment( + 0).analogsignals[0], neo.core.AnalogSignal) + + +if __name__ == '__main__': + unittest.main() diff --git a/elephant/trials.py b/elephant/trials.py new file mode 100644 index 000000000..a81012d51 --- /dev/null +++ b/elephant/trials.py @@ -0,0 +1,427 @@ +""" +This module defines the basic classes that represent trials in Elephant. + +Many neuroscience methods rely on the concept of repeated trials to improve the +estimate of quantities measured from the data. In the simplest case, results +from multiple trials are averaged, in other scenarios more intricate steps must +be taken in order to pool information from each repetition of a trial. Typically, +trials are considered as fixed time intervals tied to a specific event in the +experiment, such as the onset of a stimulus. + +Neo does not impose a specific way in which trials are to be represented. A +natural way to represent trials is to have a :class:`neo.Block` containing multiple +:class:`neo.Segment` objects, each representing the data of one trial. Another popular +option is to store trials as lists of lists, where the outer refers to +individual lists, and inner lists contain Neo data objects (:class:`neo.SpikeTrain` +and :class:`neo.AnalogSignal` containing individual data of each trial. + +The classes of this module abstract from these individual data representations +by introducing a set of :class:`Trials` classes with a common API. These classes +are initialized by a supported way of structuring trials, e.g., +:class:`TrialsFromBlock` for the first method described above. Internally, +:class:`Trials` class will not convert this representation, but provide access +to data in specific trials (e.g., all spike trains in trial 5) or general +information about the trial structure (e.g., how many trials are there?) via a +fixed API. Trials are consecutively numbered, starting at a trial ID of 0. + +In the release, the classes :class:`TrialsFromBlock` and +:class:`TrialsFromLists` provide this unified way to access trial data. + +.. autosummary:: + :toctree: _toctree/trials + :template: trials_class.rst + + TrialsFromBlock + TrialsFromLists + +:copyright: Copyright 2014-2023 by the Elephant team, see `doc/authors.rst`. +:license: Modified BSD, see LICENSE.txt for details. +""" + +from abc import ABCMeta, abstractmethod +from typing import List + +import neo.utils +from neo.core import Segment, Block +from neo.core.spiketrainlist import SpikeTrainList + + +class Trials: + """ + Base class for handling trials. + + This is the base class from which all trial objects inherit. + This class implements support for universally recommended arguments. + + Parameters + ---------- + description : string, optional + A textual description of the set of trials. Can be accessed via the + class attribute `description`. + Default: None. + + """ + + __metaclass__ = ABCMeta + + def __init__(self, description: str = "Trials"): + """Create an instance of the trials class.""" + self.description = description + + @abstractmethod + def __getitem__(self, trial_number: int) -> neo.core.Segment: + """Get a specific trial by number.""" + pass + + @abstractmethod + def n_trials(self) -> int: + """Get the number of trials. + + Returns + ------- + int: Number of trials + """ + pass + + @abstractmethod + def n_spiketrains_trial_by_trial(self) -> List[int]: + """Get the number of spike trains in each trial as a list. + + Returns + ------- + list of int: For each trial, contains the number of spike trains in the + trial. + """ + pass + + @abstractmethod + def n_analogsignals_trial_by_trial(self) -> List[int]: + """Get the number of analogsignal objects in each trial as a list. + + Returns + ------- + list of int: For each trial, contains the number of analogsignal objects + in the trial. + """ + pass + + @abstractmethod + def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: + """Get trial as segment. + + Parameters + ---------- + trial_id : int + Trial number to get (starting at trial ID 0). + + Returns + ------- + class:`neo.Segment`: a segment containing all spike trains and + analogsignal objects of the trial. + """ + pass + + @abstractmethod + def get_trials_as_block(self, trial_ids: List[int] = None + ) -> neo.core.Block: + """Get trials as block. + + Parameters + ---------- + trial_ids : list of int + Trial IDs to include in the Block (starting at trial ID 0). + If None is specified, all trials are returned. + Default: None + + Returns + ------- + class:`neo.Block`: a Block containing :class:`neo.Segment` objects for + each of the selected trials, each containing all spike trains and + analogsignal objects of the corresponding trial. + """ + pass + + @abstractmethod + def get_trials_as_list(self, trial_ids: List[int] = None + ) -> neo.core.spiketrainlist.SpikeTrainList: + """Get trials as list of segments. + + Parameters + ---------- + trial_ids : list of int + Trial IDs to include in the list (starting at trial ID 0). + If None is specified, all trials are returned. + Default: None + + Returns + ------- + list of :class:`neo.Segment`: a list containing :class:`neo.Segment` + objects for each of the selected trials, each containing all spike + trains and analogsignal objects of the corresponding trial. + """ + pass + + @abstractmethod + def get_spiketrains_from_trial_as_list(self, trial_id: int) -> ( + neo.core.spiketrainlist.SpikeTrainList): + """ + Get all spike trains from a specific trial and return a list. + + Parameters + ---------- + trial_id : int + Trial ID to get the spike trains from (starting at trial ID 0). + + Returns + ------- + list of :class:`neo.SpikeTrain` + List of all spike trains of the trial. + """ + pass + + @abstractmethod + def get_spiketrains_from_trial_as_segment(self, trial_id: int + ) -> neo.core.Segment: + """ + Get all spike trains from a specific trial and return a Segment. + + Parameters + ---------- + trial_id : int + Trial ID to get the spike trains from (starting at trial ID 0). + + Returns + ------- + :class:`neo.Segment`: Segment containing all spike trains of the trial. + """ + pass + + @abstractmethod + def get_analogsignals_from_trial_as_list(self, trial_id: int + ) -> List[neo.core.AnalogSignal]: + """ + Get all analogsignals from a specific trial and return a list. + + Parameters + ---------- + trial_id : int + Trial ID to get the analogsignals from (starting at trial ID 0). + + Returns + ------- + list of :class`neo.AnalogSignal`: list of all analogsignal objects of + the trial. + """ + pass + + @abstractmethod + def get_analogsignals_from_trial_as_segment(self, trial_id: int + ) -> neo.core.Segment: + """ + Get all analogsignal objects from a specific trial and return a + :class:`neo.Segment`. + + Parameters + ---------- + trial_id : int + Trial ID to get the analogsignals from (starting at trial ID 0). + + Returns + ------- + class:`neo.Segment`: segment containing all analogsignal objects of + the trial. + """ + + +class TrialsFromBlock(Trials): + """ + This class implements support for handling trials from neo.Block. + + Parameters + ---------- + block : neo.Block + An instance of neo.Block containing the trials. + The structure is assumed to follow the neo representation: + A block contains multiple segments which are considered to contain the + single trials. + description : string, optional + A textual description of the set of trials. Can be accessed via the + class attribute `description`. + Default: None. + """ + + def __init__(self, block: neo.core.block, **kwargs): + self.block = block + super().__init__(**kwargs) + + def __getitem__(self, trial_number: int) -> neo.core.segment: + return self.block.segments[trial_number] + + def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: + # Get a specific trial by number as a segment + return self.__getitem__(trial_id) + + def get_trials_as_block(self, trial_ids: List[int] = None + ) -> neo.core.Block: + # Get a block of trials by trial numbers + block = Block() + if not trial_ids: + trial_ids = list(range(self.n_trials)) + for trial_number in trial_ids: + block.segments.append(self.get_trial_as_segment(trial_number)) + return block + + def get_trials_as_list(self, trial_ids: List[int] = None + ) -> List[neo.core.Segment]: + if not trial_ids: + trial_ids = list(range(self.n_trials)) + # Get a list of segments by trial numbers + return [self.get_trial_as_segment(trial_number) + for trial_number in trial_ids] + + @property + def n_trials(self) -> int: + # Get the number of trials. + return len(self.block.segments) + + @property + def n_spiketrains_trial_by_trial(self) -> List[int]: + # Get the number of SpikeTrain instances in each trial. + return [len(trial.spiketrains) for trial in self.block.segments] + + @property + def n_analogsignals_trial_by_trial(self) -> List[int]: + # Get the number of AnalogSignals instances in each trial. + return [len(trial.analogsignals) for trial in self.block.segments] + + def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( + neo.core.spiketrainlist.SpikeTrainList): + # Return a list of all spike trains from a trial + return SpikeTrainList(items=[spiketrain for spiketrain in + self.block.segments[trial_id].spiketrains]) + + def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( + neo.core.Segment): + # Return a segment with all spiketrains from a trial + segment = neo.core.Segment() + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id + ): + segment.spiketrains.append(spiketrain) + return segment + + def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( + List[neo.core.AnalogSignal]): + # Return a list of all analogsignals from a trial + return [analogsignal for analogsignal in + self.block.segments[trial_id].analogsignals] + + def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( + neo.core.Segment): + # Return a segment with all analogsignals from a trial + segment = neo.core.Segment() + for analogsignal in self.get_analogsignals_from_trial_as_list( + trial_id): + segment.analogsignals.append(analogsignal) + return segment + + +class TrialsFromLists(Trials): + """ + This class implements support for handling trials from list of lists. + + Parameters + ---------- + list_of_trials : list of lists + A list of lists. Each list entry contains a list of neo.SpikeTrains + and/or neo.AnalogSignals. + description : string, optional + A textual description of the set of trials. Can be accessed via the + class attribute `description`. + Default: None. + """ + + def __init__(self, list_of_trials: list, **kwargs): + # Constructor + # (actual documentation is in class documentation, see above!) + self.list_of_trials = list_of_trials + super().__init__(**kwargs) + + def __getitem__(self, trial_number: int) -> neo.core.Segment: + # Get a specific trial by number + segment = Segment() + for element in self.list_of_trials[trial_number]: + if isinstance(element, neo.core.SpikeTrain): + segment.spiketrains.append(element) + if isinstance(element, neo.core.AnalogSignal): + segment.analogsignals.append(element) + return segment + + def get_trial_as_segment(self, trial_id: int) -> neo.core.Segment: + # Get a specific trial by number as a segment + return self.__getitem__(trial_id) + + def get_trials_as_block(self, trial_ids: List[int] = None + ) -> neo.core.Block: + if not trial_ids: + trial_ids = list(range(self.n_trials)) + # Get a block of trials by trial numbers + block = Block() + for trial_number in trial_ids: + block.segments.append(self.get_trial_as_segment(trial_number)) + return block + + def get_trials_as_list(self, trial_ids: List[int] = None + ) -> List[neo.core.Segment]: + if not trial_ids: + trial_ids = list(range(self.n_trials)) + # Get a list of segments by trial numbers + return [self.get_trial_as_segment(trial_number) + for trial_number in trial_ids] + + @property + def n_trials(self) -> int: + # Get the number of trials. + return len(self.list_of_trials) + + @property + def n_spiketrains_trial_by_trial(self) -> List[int]: + # Get the number of spiketrains in each trial. + return [sum(map(lambda x: isinstance(x, neo.core.SpikeTrain), trial)) + for trial in self.list_of_trials] + + @property + def n_analogsignals_trial_by_trial(self) -> List[int]: + # Get the number of analogsignals in each trial. + return [sum(map(lambda x: isinstance(x, neo.core.AnalogSignal), trial)) + for trial in self.list_of_trials] + + def get_spiketrains_from_trial_as_list(self, trial_id: int = 0) -> ( + neo.core.spiketrainlist.SpikeTrainList): + # Return a list of all spiketrains from a trial + return SpikeTrainList(items=[ + spiketrain for spiketrain in self.list_of_trials[trial_id] + if isinstance(spiketrain, neo.core.SpikeTrain)]) + + def get_spiketrains_from_trial_as_segment(self, trial_id: int) -> ( + neo.core.Segment): + # Return a segment with all spiketrains from a trial + segment = neo.core.Segment() + for spiketrain in self.get_spiketrains_from_trial_as_list(trial_id): + segment.spiketrains.append(spiketrain) + return segment + + def get_analogsignals_from_trial_as_list(self, trial_id: int) -> ( + List[neo.core.AnalogSignal]): + # Return a list of all analogsignals from a trial + return [analogsignal for analogsignal in + self.list_of_trials[trial_id] + if isinstance(analogsignal, neo.core.AnalogSignal)] + + def get_analogsignals_from_trial_as_segment(self, trial_id: int) -> ( + neo.core.Segment): + # Return a segment with all analogsignals from a trial + segment = neo.core.Segment() + for analogsignal in self.get_analogsignals_from_trial_as_list( + trial_id): + segment.analogsignals.append(analogsignal) + return segment