Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Integrate trials object with Unitary Event Analysis (UE) #643

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
19 changes: 7 additions & 12 deletions doc/tutorials/unitary_event_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"\n",
"import elephant.unitary_event_analysis as ue\n",
"from elephant.datasets import download_datasets\n",
"from elephant.trials import TrialsFromBlock\n",
"\n",
"# Fix random seed to guarantee fixed output\n",
"random.seed(1224)"
Expand Down Expand Up @@ -451,10 +452,7 @@
"io = neo.io.NixIO(f\"{filepath}\",'ro')\n",
"block = io.read_block()\n",
"\n",
"spiketrains = []\n",
"# each segment contains a single trial\n",
"for ind in range(len(block.segments)):\n",
" spiketrains.append (block.segments[ind].spiketrains)\n"
"spiketrains = TrialsFromBlock(block)\n"
]
},
{
Expand All @@ -473,19 +471,16 @@
"UE = ue.jointJ_window_analysis(\n",
" spiketrains, bin_size=5*pq.ms, win_size=100*pq.ms, win_step=10*pq.ms, pattern_hash=[3])\n",
"\n",
"plot_ue(spiketrains, UE, significance_level=0.05)\n",
"plt.show()"
"plot_ue([spiketrains.get_spiketrains_from_trial_as_list(idx) for idx in range(spiketrains.n_trials)], UE, significance_level=0.05)\n",
"plt.show()\n"
Moritz-Alexander-Kern marked this conversation as resolved.
Show resolved Hide resolved
]
}
],
"metadata": {
"interpreter": {
"hash": "623e048a0474aa032839f97d38ba0837cc9041adc49a14b480c72f2df8ea99e3"
},
"kernelspec": {
"display_name": "inm-elephant",
"display_name": "Python 3",
"language": "python",
"name": "inm-elephant"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -497,7 +492,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.12.5"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
Expand Down
220 changes: 115 additions & 105 deletions elephant/test/test_unitary_event_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import quantities as pq
from numpy.testing import assert_array_equal

from elephant.trials import TrialsFromLists
import elephant.unitary_event_analysis as ue
from elephant.datasets import download, ELEPHANT_TMP_DIR
from numpy.testing import assert_array_almost_equal
Expand Down Expand Up @@ -324,52 +325,56 @@ def test_jointJ_window_analysis(self):
sts2 = self.sts2_neo

# joinJ_window_analysis requires the following:
# A list of spike trains(neo.SpikeTrain objects) in different trials:
data = list(zip(*[sts1,sts2]))

win_size = 100 * pq.ms
bin_size = 5 * pq.ms
win_step = 20 * pq.ms
pattern_hash = [3]
UE_dic = ue.jointJ_window_analysis(spiketrains=data,
pattern_hash=pattern_hash,
bin_size=bin_size,
win_size=win_size,
win_step=win_step)
expected_Js = np.array(
[0.57953708, 0.47348757, 0.1729669,
0.01883295, -0.21934742, -0.80608759])
expected_n_emp = np.array(
[9., 9., 7., 7., 6., 6.])
expected_n_exp = np.array(
[6.5, 6.85, 6.05, 6.6, 6.45, 8.7])
expected_rate = np.array(
[[0.02166667, 0.01861111],
[0.02277778, 0.01777778],
[0.02111111, 0.01777778],
[0.02277778, 0.01888889],
[0.02305556, 0.01722222],
[0.02388889, 0.02055556]]) * pq.kHz
expected_indecis_tril26 = [4., 4.]
expected_indecis_tril4 = [1.]
assert_array_almost_equal(UE_dic['Js'].squeeze(), expected_Js)
assert_array_almost_equal(UE_dic['n_emp'].squeeze(), expected_n_emp)
assert_array_almost_equal(UE_dic['n_exp'].squeeze(), expected_n_exp)
assert_array_almost_equal(UE_dic['rate_avg'].squeeze(), expected_rate)
assert_array_almost_equal(UE_dic['indices']['trial26'],
expected_indecis_tril26)
assert_array_almost_equal(UE_dic['indices']['trial4'],
expected_indecis_tril4)

# check the input parameters
input_params = UE_dic['input_parameters']
self.assertEqual(input_params['pattern_hash'], pattern_hash)
self.assertEqual(input_params['bin_size'], bin_size)
self.assertEqual(input_params['win_size'], win_size)
self.assertEqual(input_params['win_step'], win_step)
self.assertEqual(input_params['method'], 'analytic_TrialByTrial')
self.assertEqual(input_params['t_start'], 0 * pq.s)
self.assertEqual(input_params['t_stop'], 200 * pq.ms)
Comment on lines -330 to -372
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only indentation level changed.

# A list of spike trains(neo.SpikeTrain objects) in different trials, or trials.Trial object
test_cases = (
list(zip(*[sts1, sts2])), # list
TrialsFromLists(list(zip(*[sts1, sts2]))), # Trial object
)
for data in test_cases:
with self.subTest(data=data):
win_size = 100 * pq.ms
bin_size = 5 * pq.ms
win_step = 20 * pq.ms
pattern_hash = [3]
UE_dic = ue.jointJ_window_analysis(spiketrains=data,
pattern_hash=pattern_hash,
bin_size=bin_size,
win_size=win_size,
win_step=win_step)
expected_Js = np.array(
[0.57953708, 0.47348757, 0.1729669,
0.01883295, -0.21934742, -0.80608759])
expected_n_emp = np.array(
[9., 9., 7., 7., 6., 6.])
expected_n_exp = np.array(
[6.5, 6.85, 6.05, 6.6, 6.45, 8.7])
expected_rate = np.array(
[[0.02166667, 0.01861111],
[0.02277778, 0.01777778],
[0.02111111, 0.01777778],
[0.02277778, 0.01888889],
[0.02305556, 0.01722222],
[0.02388889, 0.02055556]]) * pq.kHz
expected_indecis_tril26 = [4., 4.]
expected_indecis_tril4 = [1.]
assert_array_almost_equal(UE_dic['Js'].squeeze(), expected_Js)
assert_array_almost_equal(UE_dic['n_emp'].squeeze(), expected_n_emp)
assert_array_almost_equal(UE_dic['n_exp'].squeeze(), expected_n_exp)
assert_array_almost_equal(UE_dic['rate_avg'].squeeze(), expected_rate)
assert_array_almost_equal(UE_dic['indices']['trial26'],
expected_indecis_tril26)
assert_array_almost_equal(UE_dic['indices']['trial4'],
expected_indecis_tril4)

# check the input parameters
input_params = UE_dic['input_parameters']
self.assertEqual(input_params['pattern_hash'], pattern_hash)
self.assertEqual(input_params['bin_size'], bin_size)
self.assertEqual(input_params['win_size'], win_size)
self.assertEqual(input_params['win_step'], win_step)
self.assertEqual(input_params['method'], 'analytic_TrialByTrial')
self.assertEqual(input_params['t_start'], 0 * pq.s)
self.assertEqual(input_params['t_stop'], 200 * pq.ms)

@staticmethod
def load_gdf2Neo(fname, trigger, t_pre, t_post):
Expand Down Expand Up @@ -501,69 +506,74 @@ def test_multiple_neurons(self):
np.random.seed(12)

# Create a list of lists containing 3 Trials with 5 spiketrains
spiketrains = \
spiketrains_poisson = \
[StationaryPoissonProcess(
rate=50 * pq.Hz, t_stop=1 * pq.s).generate_n_spiketrains(5)
for _ in range(3)]

spiketrains = list(zip(*spiketrains))
UE_dic = ue.jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
win_size=300 * pq.ms,
win_step=100 * pq.ms)

js_expected = [[0.3978179],
[0.08131966],
[-1.4239882],
[-0.9377029],
[-0.3374434],
[-0.2043383],
[-1.001536],
[-np.inf]]
indices_expected = \
{'trial3': [12, 27, 31, 34, 27, 31, 34, 136, 136, 136],
'trial4': [4, 60, 60, 60, 117, 117, 117]}
n_emp_expected = [[5.],
[4.],
[1.],
[2.],
[2.],
[2.],
[1.],
[0.]]
n_exp_expected = [[3.5591667],
[3.4536111],
[3.3158333],
[3.8466666],
[2.370278],
[2.0811112],
[2.4011111],
[3.0533333]]
rate_expected = [[[0.042, 0.03933334, 0.048]],
[[0.04533333, 0.038, 0.05]],
[[0.046, 0.04, 0.04666667]],
[[0.05066667, 0.042, 0.046]],
[[0.04466667, 0.03666667, 0.04066667]],
[[0.04066667, 0.03533333, 0.04333333]],
[[0.03933334, 0.038, 0.038]],
[[0.04066667, 0.04866667, 0.03666667]]] * (1. / pq.ms)
input_parameters_expected = {'pattern_hash': [7],
'bin_size': 5 * pq.ms,
'win_size': 300 * pq.ms,
'win_step': 100 * pq.ms,
'method': 'analytic_TrialByTrial',
't_start': 0 * pq.s,
't_stop': 1 * pq.s, 'n_surrogates': 100}

assert_array_almost_equal(UE_dic['Js'], js_expected)
assert_array_almost_equal(UE_dic['n_emp'], n_emp_expected)
assert_array_almost_equal(UE_dic['n_exp'], n_exp_expected)
assert_array_almost_equal(UE_dic['rate_avg'], rate_expected)
self.assertEqual(sorted(UE_dic['indices'].keys()),
sorted(indices_expected.keys()))
for trial_key in indices_expected.keys():
assert_array_equal(indices_expected[trial_key],
UE_dic['indices'][trial_key])
self.assertEqual(UE_dic['input_parameters'], input_parameters_expected)
Comment on lines -510 to -566
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only indentation level changed.

test_cases = (
list(zip(*spiketrains_poisson)), # list
TrialsFromLists(list(zip(*spiketrains_poisson))), # Trial object
)
for spiketrains in test_cases:
with self.subTest(data=spiketrains):
UE_dic = ue.jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
win_size=300 * pq.ms,
win_step=100 * pq.ms)

js_expected = [[0.3978179],
[0.08131966],
[-1.4239882],
[-0.9377029],
[-0.3374434],
[-0.2043383],
[-1.001536],
[-np.inf]]
indices_expected = \
{'trial3': [12, 27, 31, 34, 27, 31, 34, 136, 136, 136],
'trial4': [4, 60, 60, 60, 117, 117, 117]}
n_emp_expected = [[5.],
[4.],
[1.],
[2.],
[2.],
[2.],
[1.],
[0.]]
n_exp_expected = [[3.5591667],
[3.4536111],
[3.3158333],
[3.8466666],
[2.370278],
[2.0811112],
[2.4011111],
[3.0533333]]
rate_expected = [[[0.042, 0.03933334, 0.048]],
[[0.04533333, 0.038, 0.05]],
[[0.046, 0.04, 0.04666667]],
[[0.05066667, 0.042, 0.046]],
[[0.04466667, 0.03666667, 0.04066667]],
[[0.04066667, 0.03533333, 0.04333333]],
[[0.03933334, 0.038, 0.038]],
[[0.04066667, 0.04866667, 0.03666667]]] * (1. / pq.ms)
input_parameters_expected = {'pattern_hash': [7],
'bin_size': 5 * pq.ms,
'win_size': 300 * pq.ms,
'win_step': 100 * pq.ms,
'method': 'analytic_TrialByTrial',
't_start': 0 * pq.s,
't_stop': 1 * pq.s, 'n_surrogates': 100}

assert_array_almost_equal(UE_dic['Js'], js_expected)
assert_array_almost_equal(UE_dic['n_emp'], n_emp_expected)
assert_array_almost_equal(UE_dic['n_exp'], n_exp_expected)
assert_array_almost_equal(UE_dic['rate_avg'], rate_expected)
self.assertEqual(sorted(UE_dic['indices'].keys()),
sorted(indices_expected.keys()))
for trial_key in indices_expected.keys():
assert_array_equal(indices_expected[trial_key],
UE_dic['indices'][trial_key])
self.assertEqual(UE_dic['input_parameters'], input_parameters_expected)


if __name__ == '__main__':
Expand Down
5 changes: 3 additions & 2 deletions elephant/unitary_event_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import scipy

import elephant.conversion as conv
from elephant.utils import is_binary
from elephant.utils import is_binary, trials_to_list_of_spiketrainlist

__all__ = [
"hash_from_pattern",
Expand Down Expand Up @@ -689,6 +689,7 @@ def _UE(mat, pattern_hash, method='analytic_TrialByTrial', n_surrogates=1):
return Js, rate_avg, n_exp, n_emp, indices


@trials_to_list_of_spiketrainlist
def jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,
win_size=100 * pq.ms, win_step=5 * pq.ms,
pattern_hash=None, method='analytic_TrialByTrial',
Expand All @@ -701,7 +702,7 @@ def jointJ_window_analysis(spiketrains, bin_size=5 * pq.ms,

Parameters
----------
spiketrains : list
spiketrains : :class:`elephant.trials.Trials`, list
A list of spike trains (`neo.SpikeTrain` objects) in different trials:
* 0-axis --> Trials

Expand Down
Loading