Skip to content

Commit

Permalink
Add spike train types
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Sep 16, 2023
1 parent eea80cb commit e64b280
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/miv_simulator/coding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from miv_simulator import typing as st
import numpy as np


def spike_times_2_binary_sparse_spike_train(
array: st.SpikeTimesLike, temporal_resolution: float
) -> st.BinarySparseSpikeTrain:
a = st.cast_spike_times(array)
bins = np.floor(a / temporal_resolution).astype(int)
# since a is sorted, maximum is last value
spike_train = np.zeros(bins[-1] + 1, dtype=np.int8)
spike_train[bins] = 1
return spike_train


def binary_sparse_spike_train_2_spike_times(
array: st.BinarySparseSpikeTrainLike, temporal_resolution: float
) -> st.SpikeTimes:
a = st.cast_binary_sparse_spike_train(array)
spike_indices = np.where(a == 1)[0]
spike_times = spike_indices * temporal_resolution
return spike_times


def adjust_temporal_resolution(
array: st.BinarySparseSpikeTrainLike,
original_resolution: float,
target_resolution: float,
) -> st.BinarySparseSpikeTrain:
a = st.cast_binary_sparse_spike_train(array)

ratio = target_resolution / original_resolution
if ratio == 1:
return a

new_length = int(a.shape[0] * ratio)
new_spike_train = np.zeros(new_length, dtype=np.int8)

# up
if ratio > 1:
for idx, val in enumerate(a):
start = int(idx * ratio)
end = int((idx + 1) * ratio)
new_spike_train[start:end] = val

# down
elif ratio < 1:
for idx in range(0, len(a), int(1 / ratio)):
if np.any(a[idx : idx + int(1 / ratio)]):
new_spike_train[idx // int(1 / ratio)] = 1

return new_spike_train
60 changes: 60 additions & 0 deletions src/miv_simulator/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from numpy.typing import NDArray
import numpy as np
from typing import Annotated as EventArray, Dict

"""Potentially unsorted or scalar data that can be transformed into `SpikeTimes`"""
SpikeTimesLike = EventArray[NDArray[np.float_], "SpikeTimesLike ..."]

"""Sorted array of absolute spike times"""
SpikeTimes = EventArray[NDArray[np.float_], "SpikeTimes T ..."]

# spike train encodings (RLE, delta encoding, variable time binning etc.)

"""Binary data that can be cast to the `BinarySparseSpikeTrain` format"""
BinarySparseSpikeTrainLike = EventArray[
NDArray, "BinarySparseSpikeTrainLike ..."
]

"""Binary spike train representation for a given temporal resolution"""
BinarySparseSpikeTrain = EventArray[
NDArray[np.int8], "BinarySparseSpikeTrain t_bin ..."
]


def _inspect(type_) -> Dict:
annotation = type_.__metadata__[0]
name, *dims = annotation.split(" ")

return {
"annotation": annotation,
"name": name,
"dims": dims,
"dtype": type_.__origin__.__args__[1].__args__[0],
}


def _cast(a, a_type, r_type): # -> r_type
a_t, r_t = _inspect(a_type), _inspect(r_type)
if a_t["name"].replace("Like", "") != r_t["name"]:
raise ValueError(
f"Expected miv_simulator.typing.{r_t['name']}Like but found {a_t['name']}"
)
v = np.array(a, dtype=r_t["dtype"])
if len(v.shape) == 0:
return np.reshape(
v,
[
1,
],
)
return v


def cast_spike_times(a: SpikeTimesLike) -> SpikeTimes:
return np.sort(_cast(a, SpikeTimesLike, SpikeTimes), axis=0)


def cast_binary_sparse_spike_train(
a: BinarySparseSpikeTrainLike,
) -> BinarySparseSpikeTrain:
return _cast(a, BinarySparseSpikeTrainLike, BinarySparseSpikeTrain)
40 changes: 40 additions & 0 deletions tests/test_coding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from miv_simulator import coding as t
import numpy as np
import miv_simulator.typing as st


def test_coding_spike_times_vs_binary_sparse_spike_train():
for a, b in [
([0.1, 0.3, 0.4, 0.85], [1, 1]),
([0.8], [0, 1]),
]:
result = t.spike_times_2_binary_sparse_spike_train(a, 0.5)
expected = np.array(b, dtype=np.int8)
assert np.array_equal(result, expected)

for a, b in [
([1, 0, 1], [0.0, 1.0]),
([0, 1], [0.5]),
]:
spike_train = np.array(a, dtype=np.int8)
result = t.binary_sparse_spike_train_2_spike_times(spike_train, 0.5)
expected = np.array(b)
assert np.array_equal(result, expected)


def test_coding_adjust_temporal_resolution():
spike_train = np.array([0, 1, 0, 1, 0], dtype=np.int8)

# identity
adjusted = t.adjust_temporal_resolution(spike_train, 1, 1)
assert np.array_equal(adjusted, spike_train)

# up
adjusted = t.adjust_temporal_resolution(spike_train, 0.5, 1)
expected = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0], dtype=np.int8)
assert np.array_equal(adjusted, expected)

# down
adjusted = t.adjust_temporal_resolution(spike_train, 2, 1)
expected = np.array([1, 1], dtype=np.int8)
assert np.array_equal(adjusted, expected)
10 changes: 10 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from miv_simulator import typing as t
import numpy as np


def test_typing_cast():
assert t.cast_spike_times(0.5).shape == (1,)
assert t.cast_spike_times([0.5, 0.1])[1] == 0.5
assert t.cast_spike_times(int(1))[0] == float(1.0)

assert t.cast_binary_sparse_spike_train(0.1)[0] == 0

0 comments on commit e64b280

Please sign in to comment.