Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Integrate trials object with GPFA #610

Merged
merged 56 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
ce3db6b
add trial class to docstrings
Dec 8, 2023
e9ed6d0
add first input checks for Trials object
Dec 8, 2023
a75dd92
fix input checks
Dec 8, 2023
734a6b1
add trial handling to GPFA.fit() method
Dec 8, 2023
a97f57f
add trial handling to transform method
Dec 8, 2023
99bcd54
add docstring to GPFA.score
Dec 8, 2023
f24ac10
add notebook for testing
Dec 8, 2023
a16f552
update test notebook
Dec 8, 2023
c537561
fix pep8
Dec 14, 2023
89d0ea7
use setupclass for faster unittests
Jan 24, 2024
40ece76
optimize imports for GPFA unittests
Jan 24, 2024
c55f5c6
remove python 2 support
Jan 24, 2024
5386e95
use f-strings for error messages
Jan 24, 2024
165fdfc
implement trial handling for fit and transform, and add unittests
Jan 24, 2024
b8e5192
fix input checks, added test
Jan 24, 2024
59aa188
update docstrings
Jan 24, 2024
61adc7b
fix bug related to scipy
Jan 24, 2024
b77fa27
fix type annotation
Jan 24, 2024
b2b4116
change type annotation
Jan 24, 2024
016baa6
update typing
Jan 24, 2024
5772a6e
fix imports
Jan 24, 2024
5be4580
fix imports
Jan 24, 2024
6559a7d
add type annotations to fit_transform and score
Jan 24, 2024
172dc6d
update tutorial
Feb 7, 2024
adbd68f
format gpfa.py with black
Feb 7, 2024
77c373a
reordered functions
Feb 13, 2024
9baf45d
fix typos
Feb 13, 2024
70dbb6f
undo fix for scipy
Feb 13, 2024
31b1687
fix pep8
Feb 21, 2024
ac8a45d
add decorators to handle trial object
Feb 21, 2024
54dd534
remove type check
Feb 21, 2024
5e08661
add docstring for decorator
Feb 21, 2024
39e73f1
add functions for spiketrains
Feb 21, 2024
6d573b6
formatting
Feb 21, 2024
21373b9
Fix docstring formatting
Feb 28, 2024
f0c0f5d
add condition for no spiketrains
Feb 28, 2024
729be77
change import of sklearn for tests
Mar 18, 2024
8f621fe
fix pep8
Mar 18, 2024
292c3db
remove logic from decorator
Mar 20, 2024
77dfc50
add type annotations to fit method
Mar 20, 2024
ce3d679
add type annotations to transform method
Mar 20, 2024
7b6684a
update tutorial
Mar 20, 2024
e00d880
pep8 formatting
Mar 20, 2024
efe8195
remove development notebook
Mar 21, 2024
43a0c4c
add decorator to score and fit_transform
Mar 25, 2024
61f04d9
make decorator iterate over args and kwargs and replace relevant item
Mar 25, 2024
6009a5a
Merge branch 'master' into feature/trials_gpfa
Mar 25, 2024
c303434
add unit tests for new decorator
Mar 25, 2024
8b0caf6
fix unittest
Mar 25, 2024
133f8ef
use variable n_trials instead of hard-coding this number
Mar 25, 2024
f62e1de
update comments
Mar 26, 2024
af0481a
remove self argument from decorator
Mar 26, 2024
e4fc728
fix docstring for decorator
Mar 26, 2024
ed33527
fix pep8
Mar 26, 2024
9c4cae2
fix typo in example for decorator
Mar 26, 2024
f993f8c
fix coverage report
Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 182 additions & 105 deletions elephant/gpfa/gpfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,18 @@
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

from typing import List, Union
import neo
import numpy as np
import quantities as pq
import sklearn

from elephant.gpfa import gpfa_core, gpfa_util
from elephant.trials import Trials
from elephant.utils import trials_to_list_of_spiketrainlist


__all__ = [
"GPFA"
]
__all__ = ["GPFA"]


class GPFA(sklearn.base.BaseEstimator):
Expand Down Expand Up @@ -228,9 +227,18 @@ class GPFA(sklearn.base.BaseEstimator):
... 'latent_variable'])
"""

def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
tau_init=100.0 * pq.ms, eps_init=1.0E-3, em_tol=1.0E-8,
em_max_iters=500, freq_ll=5, verbose=False):
def __init__(
self,
bin_size=20 * pq.ms,
x_dim=3,
min_var_frac=0.01,
tau_init=100.0 * pq.ms,
eps_init=1.0e-3,
em_tol=1.0e-8,
em_max_iters=500,
freq_ll=5,
verbose=False,
):
# Initialize object
self.bin_size = bin_size
self.x_dim = x_dim
Expand All @@ -241,11 +249,12 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
self.em_max_iters = em_max_iters
self.freq_ll = freq_ll
self.valid_data_names = (
'latent_variable_orth',
'latent_variable',
'Vsm',
'VsmGP',
'y')
"latent_variable_orth",
"latent_variable",
"Vsm",
"VsmGP",
"y",
)
self.verbose = verbose

if not isinstance(self.bin_size, pq.Quantity):
Expand All @@ -258,17 +267,53 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
self.fit_info = dict()
self.transform_info = dict()

def fit(self, spiketrains):
@staticmethod
def _check_training_data(
    spiketrains: List[List[neo.core.SpikeTrain]],
) -> None:
    """
    Validate the trials-by-neurons structure of the training input.

    Parameters
    ----------
    spiketrains : list of list of neo.SpikeTrain
        Outer list: trials; inner list: the spike trains recorded in
        that trial.

    Raises
    ------
    ValueError
        If `spiketrains` is empty, or if any inner element is not a
        `neo.SpikeTrain`.
    """
    if not len(spiketrains):
        raise ValueError("Input spiketrains can not be empty")
    # Walk every trial and every unit; fail on the first element that
    # is not a neo.SpikeTrain.
    for trial in spiketrains:
        for unit in trial:
            if not isinstance(unit, neo.SpikeTrain):
                raise ValueError(
                    "structure of the spiketrains is not "
                    "correct: 0-axis should be trials, 1-axis "
                    "neo.SpikeTrain and 2-axis spike times."
                )

def _format_training_data(
    self, spiketrains: List[List[neo.core.SpikeTrain]]
) -> np.recarray:
    """
    Bin the spike trains and drop units that never fire.

    Side effect: stores the boolean mask of active units in
    ``self.has_spikes_bool`` so that ``transform`` can apply the same
    unit selection later.
    """
    binned_seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
    # A unit is considered active if it spikes in at least one bin of
    # any trial; inactive units are removed from every trial.
    active_mask = np.hstack(binned_seqs["y"]).any(axis=1)
    self.has_spikes_bool = active_mask
    for trial_seq in binned_seqs:
        trial_seq["y"] = trial_seq["y"][active_mask, :]
    return binned_seqs

@trials_to_list_of_spiketrainlist
def fit(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
) -> "GPFA":
"""
Fit the model with the given training data.

Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : :class:`elephant.trials.Trials`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or list of list of :class:`neo.core.SpikeTrain`
Spike train data to be fit to latent variables.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
`spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
For list of lists, the outer list corresponds to trials and the
inner list corresponds to the neurons recorded in that trial, such
that `spiketrains[l][n]` is the spike train of neuron `n` in trial
`l`.
Note that the number and order of `neo.SpikeTrain` objects per
trial must be fixed such that `spiketrains[l][n]` and
`spiketrains[k][n]` refer to spike trains of the same neuron
Expand All @@ -288,69 +333,74 @@ def fit(self, spiketrains):

If covariance matrix of input spike data is rank deficient.
"""
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)
# Check if training data covariance is full rank
y_all = np.hstack(seqs_train['y'])
y_dim = y_all.shape[0]

if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
errmesg = 'Observation covariance matrix is rank deficient.\n' \
'Possible causes: ' \
'repeated units, not enough observations.'
raise ValueError(errmesg)

if self.verbose:
print('Number of training trials: {}'.format(len(seqs_train)))
print('Latent space dimensionality: {}'.format(self.x_dim))
print('Observation dimensionality: {}'.format(
self.has_spikes_bool.sum()))

# The following does the heavy lifting.
self.params_estimated, self.fit_info = gpfa_core.fit(
seqs_train=seqs_train,
x_dim=self.x_dim,
bin_width=self.bin_size.rescale('ms').magnitude,
min_var_frac=self.min_var_frac,
em_max_iters=self.em_max_iters,
em_tol=self.em_tol,
tau_init=self.tau_init.rescale('ms').magnitude,
eps_init=self.eps_init,
freq_ll=self.freq_ll,
verbose=self.verbose)

return self

@staticmethod
def _check_training_data(spiketrains):
    """
    Validate the trials-by-neurons structure of the training input.

    Raises
    ------
    ValueError
        If `spiketrains` is empty or its first element is not a
        `neo.SpikeTrain`.
    """
    if len(spiketrains) == 0:
        raise ValueError("Input spiketrains cannot be empty")
    # NOTE: only the first trial's first unit is inspected; malformed
    # entries deeper in the structure are not detected here.
    if not isinstance(spiketrains[0][0], neo.SpikeTrain):
        # Fixed: the original string concatenation was missing a space
        # ("neo.SpikeTrainand 2-axis").
        raise ValueError("structure of the spiketrains is not correct: "
                         "0-axis should be trials, 1-axis neo.SpikeTrain "
                         "and 2-axis spike times")

def _format_training_data(self, spiketrains):
    """
    Bin the spike trains and remove units with no spikes.

    Stores the mask of active units in ``self.has_spikes_bool`` so the
    same unit selection can be reused later (e.g. at transform time).
    """
    # gpfa_util.get_seqs bins the spike trains with self.bin_size;
    # seqs is a record array with field 'y' per trial.
    seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
    # Remove inactive units based on training set
    self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1)
    for seq in seqs:
        seq['y'] = seq['y'][self.has_spikes_bool, :]
    return seqs

def transform(self, spiketrains, returned_data=['latent_variable_orth']):
if all(
isinstance(item, neo.SpikeTrain)
for sublist in spiketrains
for item in sublist
):
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)
# Check if training data covariance is full rank
y_all = np.hstack(seqs_train["y"])
y_dim = y_all.shape[0]

if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
errmesg = (
"Observation covariance matrix is rank deficient.\n"
"Possible causes: "
"repeated units, not enough observations."
)
raise ValueError(errmesg)

if self.verbose:
print("Number of training trials: {}".format(len(seqs_train)))
print("Latent space dimensionality: {}".format(self.x_dim))
print(
"Observation dimensionality: {}".format(
self.has_spikes_bool.sum()
)
)

# The following does the heavy lifting.
self.params_estimated, self.fit_info = gpfa_core.fit(
seqs_train=seqs_train,
x_dim=self.x_dim,
bin_width=self.bin_size.rescale("ms").magnitude,
min_var_frac=self.min_var_frac,
em_max_iters=self.em_max_iters,
em_tol=self.em_tol,
tau_init=self.tau_init.rescale("ms").magnitude,
eps_init=self.eps_init,
freq_ll=self.freq_ll,
verbose=self.verbose,
)
return self
else: # TODO: implement case for continuous data
raise ValueError

@trials_to_list_of_spiketrainlist
def transform(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
returned_data: str = ["latent_variable_orth"],
) -> "GPFA":
"""
Obtain trajectories of neural activity in a low-dimensional latent
variable space by inferring the posterior mean of the obtained GPFA
model and applying an orthonormalization on the latent variable space.

Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Spike train data to be transformed to latent variables.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
`spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
For list of lists, the outer list corresponds to trials and the
inner list corresponds to the neurons recorded in that trial, such
that `spiketrains[l][n]` is the spike train of neuron `n` in trial
`l`.
Note that the number and order of `neo.SpikeTrain` objects per
trial must be fixed such that `spiketrains[l][n]` and
`spiketrains[k][n]` refer to spike trains of the same neuron
Expand Down Expand Up @@ -378,7 +428,7 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):

Returns
-------
np.ndarray or dict
:class:`np.ndarray` or dict
When the length of `returned_data` is one, a single np.ndarray,
containing the requested data (the first entry in `returned_data`
keys list), is returned. Otherwise, a dict of multiple np.ndarrays
Expand Down Expand Up @@ -411,36 +461,55 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
If `returned_data` contains keys different from the ones in
`self.valid_data_names`.
"""
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError("'spiketrains' must contain the same number of "
"neurons as the training spiketrain data")
invalid_keys = set(returned_data).difference(self.valid_data_names)
if len(invalid_keys) > 0:
raise ValueError("'returned_data' can only have the following "
"entries: {}".format(self.valid_data_names))
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
for seq in seqs:
seq['y'] = seq['y'][self.has_spikes_bool, :]
seqs, ll = gpfa_core.exact_inference_with_ll(seqs,
self.params_estimated,
get_ll=True)
self.transform_info['log_likelihood'] = ll
self.transform_info['num_bins'] = seqs['T']
Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs)
self.transform_info['Corth'] = Corth
if len(returned_data) == 1:
return seqs[returned_data[0]]
return {x: seqs[x] for x in returned_data}

def fit_transform(self, spiketrains, returned_data=[
'latent_variable_orth']):
if all(
isinstance(item, neo.SpikeTrain)
for sublist in spiketrains
for item in sublist
):
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError(
"'spiketrains' must contain the same number of "
"neurons as the training spiketrain data"
)
invalid_keys = set(returned_data).difference(self.valid_data_names)
if len(invalid_keys) > 0:
raise ValueError(
"'returned_data' can only have the following "
f"entries: {self.valid_data_names}"
)
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
for seq in seqs:
seq["y"] = seq["y"][self.has_spikes_bool, :]
seqs, ll = gpfa_core.exact_inference_with_ll(
seqs, self.params_estimated, get_ll=True
)
self.transform_info["log_likelihood"] = ll
self.transform_info["num_bins"] = seqs["T"]
Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs)
self.transform_info["Corth"] = Corth
if len(returned_data) == 1:
return seqs[returned_data[0]]
return {x: seqs[x] for x in returned_data}
else: # TODO: implement case for continuous data
raise ValueError

@trials_to_list_of_spiketrainlist
def fit_transform(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
returned_data: str = ["latent_variable_orth"],
) -> "GPFA":
"""
Fit the model with `spiketrains` data and apply the dimensionality
reduction on `spiketrains`.

Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Refer to the :func:`GPFA.fit` docstring.

returned_data : list of str
Expand All @@ -465,13 +534,21 @@ def fit_transform(self, spiketrains, returned_data=[
self.fit(spiketrains)
return self.transform(spiketrains, returned_data=returned_data)

def score(self, spiketrains):
@trials_to_list_of_spiketrainlist
def score(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
) -> "GPFA":
"""
Returns the log-likelihood of the given data under the fitted model

Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Spike train data to be scored.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
Expand All @@ -487,4 +564,4 @@ def score(self, spiketrains):
Log-likelihood of the given spiketrains under the fitted model.
"""
self.transform(spiketrains)
return self.transform_info['log_likelihood']
return self.transform_info["log_likelihood"]
Loading
Loading