From e64b2800533ba420224933b74f83545407cae458 Mon Sep 17 00:00:00 2001 From: Frithjof Gressmann Date: Sat, 16 Sep 2023 14:12:07 -0500 Subject: [PATCH] Add spike train types --- src/miv_simulator/coding.py | 52 ++++++++++++++++++++++++++++++++ src/miv_simulator/typing.py | 60 +++++++++++++++++++++++++++++++++++++ tests/test_coding.py | 40 +++++++++++++++++++++++++ tests/test_typing.py | 10 +++++++ 4 files changed, 162 insertions(+) create mode 100644 src/miv_simulator/coding.py create mode 100644 src/miv_simulator/typing.py create mode 100644 tests/test_coding.py create mode 100644 tests/test_typing.py diff --git a/src/miv_simulator/coding.py b/src/miv_simulator/coding.py new file mode 100644 index 0000000..3ebbc01 --- /dev/null +++ b/src/miv_simulator/coding.py @@ -0,0 +1,52 @@ +from miv_simulator import typing as st +import numpy as np + + +def spike_times_2_binary_sparse_spike_train( + array: st.SpikeTimesLike, temporal_resolution: float +) -> st.BinarySparseSpikeTrain: + a = st.cast_spike_times(array) + bins = np.floor(a / temporal_resolution).astype(int) + # since a is sorted, maximum is last value + spike_train = np.zeros(bins[-1] + 1, dtype=np.int8) + spike_train[bins] = 1 + return spike_train + + +def binary_sparse_spike_train_2_spike_times( + array: st.BinarySparseSpikeTrainLike, temporal_resolution: float +) -> st.SpikeTimes: + a = st.cast_binary_sparse_spike_train(array) + spike_indices = np.where(a == 1)[0] + spike_times = spike_indices * temporal_resolution + return spike_times + + +def adjust_temporal_resolution( + array: st.BinarySparseSpikeTrainLike, + original_resolution: float, + target_resolution: float, +) -> st.BinarySparseSpikeTrain: + a = st.cast_binary_sparse_spike_train(array) + + ratio = target_resolution / original_resolution + if ratio == 1: + return a + + new_length = int(a.shape[0] * ratio) + new_spike_train = np.zeros(new_length, dtype=np.int8) + + # up + if ratio > 1: + for idx, val in enumerate(a): + start = int(idx * ratio) + end = int((idx + 1) * ratio) + new_spike_train[start:end] = val + + # down + elif ratio < 1: + for idx in range(0, len(a), int(1 / ratio)): + if np.any(a[idx : idx + int(1 / ratio)]): + new_spike_train[idx // int(1 / ratio)] = 1 + + return new_spike_train diff --git a/src/miv_simulator/typing.py b/src/miv_simulator/typing.py new file mode 100644 index 0000000..bf78c17 --- /dev/null +++ b/src/miv_simulator/typing.py @@ -0,0 +1,60 @@ +from numpy.typing import NDArray +import numpy as np +from typing import Annotated as EventArray, Dict + +"""Potentially unsorted or scalar data that can be transformed into `SpikeTimes`""" +SpikeTimesLike = EventArray[NDArray[np.float_], "SpikeTimesLike ..."] + +"""Sorted array of absolute spike times""" +SpikeTimes = EventArray[NDArray[np.float_], "SpikeTimes T ..."] + +# spike train encodings (RLE, delta encoding, variable time binning etc.) + +"""Binary data that can be cast to the `BinarySparseSpikeTrain` format""" +BinarySparseSpikeTrainLike = EventArray[ + NDArray, "BinarySparseSpikeTrainLike ..." +] + +"""Binary spike train representation for a given temporal resolution""" +BinarySparseSpikeTrain = EventArray[ + NDArray[np.int8], "BinarySparseSpikeTrain t_bin ..." +] + + +def _inspect(type_) -> Dict: + annotation = type_.__metadata__[0] + name, *dims = annotation.split(" ") + + return { + "annotation": annotation, + "name": name, + "dims": dims, + "dtype": type_.__origin__.__args__[1].__args__[0], + } + + +def _cast(a, a_type, r_type): # -> r_type + a_t, r_t = _inspect(a_type), _inspect(r_type) + if a_t["name"].replace("Like", "") != r_t["name"]: + raise ValueError( + f"Expected miv_simulator.typing.{r_t['name']}Like but found {a_t['name']}" + ) + v = np.array(a, dtype=r_t["dtype"]) + if len(v.shape) == 0: + return np.reshape( + v, + [ + 1, + ], + ) + return v + + +def cast_spike_times(a: SpikeTimesLike) -> SpikeTimes: + return np.sort(_cast(a, SpikeTimesLike, SpikeTimes), axis=0) + + +def cast_binary_sparse_spike_train( + a: BinarySparseSpikeTrainLike, +) -> BinarySparseSpikeTrain: + return _cast(a, BinarySparseSpikeTrainLike, BinarySparseSpikeTrain) diff --git a/tests/test_coding.py b/tests/test_coding.py new file mode 100644 index 0000000..044e51d --- /dev/null +++ b/tests/test_coding.py @@ -0,0 +1,40 @@ +from miv_simulator import coding as t +import numpy as np +import miv_simulator.typing as st + + +def test_coding_spike_times_vs_binary_sparse_spike_train(): + for a, b in [ + ([0.1, 0.3, 0.4, 0.85], [1, 1]), + ([0.8], [0, 1]), + ]: + result = t.spike_times_2_binary_sparse_spike_train(a, 0.5) + expected = np.array(b, dtype=np.int8) + assert np.array_equal(result, expected) + + for a, b in [ + ([1, 0, 1], [0.0, 1.0]), + ([0, 1], [0.5]), + ]: + spike_train = np.array(a, dtype=np.int8) + result = t.binary_sparse_spike_train_2_spike_times(spike_train, 0.5) + expected = np.array(b) + assert np.array_equal(result, expected) + + +def test_coding_adjust_temporal_resolution(): + spike_train = np.array([0, 1, 0, 1, 0], dtype=np.int8) + + # identity + adjusted = t.adjust_temporal_resolution(spike_train, 1, 1) + assert np.array_equal(adjusted, spike_train) + + # up + adjusted = t.adjust_temporal_resolution(spike_train, 0.5, 1) + expected = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0], dtype=np.int8) + assert np.array_equal(adjusted, expected) + + # down + adjusted = t.adjust_temporal_resolution(spike_train, 2, 1) + expected = np.array([1, 1], dtype=np.int8) + assert np.array_equal(adjusted, expected) diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..694205d --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,10 @@ +from miv_simulator import typing as t +import numpy as np + + +def test_typing_cast(): + assert t.cast_spike_times(0.5).shape == (1,) + assert t.cast_spike_times([0.5, 0.1])[1] == 0.5 + assert t.cast_spike_times(int(1))[0] == float(1.0) + + assert t.cast_binary_sparse_spike_train(0.1)[0] == 0