StateSpaceAnalysis.jl is a Julia package designed for fitting linear-Gaussian state space models (SSMs) using Subspace System Identification (SSID) and Expectation Maximization (EM) algorithms.
This package provides tools for preprocessing data, fitting models, and evaluating model performance, with methods especially tailored towards neuroimaging analysis:
Neuroimaging data often has epoched/batched sequences (e.g., states x timesteps x trials). StateSpaceAnalysis.jl handles epoched data by reusing computations across batches, and it includes spline temporal bases for flexible input modeling over the epoch.
Whole-brain modelling may require a large number of latent factors. StateSpaceAnalysis.jl handles scaling through efficient memory allocation, robust covariance formats (via PDMats.jl), and regularization.
We need good initialization for systems for which we don't have great domain knowledge (especially when there are many latent factors). StateSpaceAnalysis.jl handles parameter initialization through subspace identification methods adapted from ControlSystemsIdentification.jl.
This version is very preliminary, so there may be some rough edges!
You can easily install the current release of StateSpaceAnalysis.jl from the Julia General Registry:
```julia
using Pkg
Pkg.add("StateSpaceAnalysis")
```
You may want to work directly with the package, e.g., to modify custom functions for setting up your input bases. You can create a local copy by cloning the github repo:
-
Clone the repository:
git clone https://github.com/harrisonritz/StateSpaceAnalysis.jl.git cd StateSpaceAnalysis.jl
-
Open Julia in the folder and activate the package environment:
using Pkg Pkg.activate(".") Pkg.instantiate()
-
Add the package to your Julia environment:
Pkg.add(path=".") using StateSpaceAnalysis
This will install all the necessary dependencies and set up the StateSpaceAnalysis.jl package for local use.
Note: You can check which directory you are working in with pwd()
in Julia. Opening a folder in VS code sets that folder to your path. You can specify the paths in Pkg.activate("path/to/package")
and Pkg.add("path/to/package")
even in you aren't in the right folder.
S = core_struct(
prm=param_struct(
... # high-level parameters
),
dat=data_struct(
... # data and data description
),
res=results_struct(
... # fit metrics and model derivates
),
est=estimates_struct(
... # scratch space
),
mdl=model_struct(
... # estimated model parameters
),
fnc=function_struct{core_struct}(
... # custom functions for setup
)
);
This structure is used throughout the script, which allows for effective memory management (i.e., the complier can know the size of the data tensors).
@reset S = StateSpaceAnalysis.preprocess_fit(S);
Preprocessing steps within preprocess_fit(S)
:
# read in arguments, helpful for running on a cluster
S = deepcopy(StateSpaceAnalysis.read_args(S, ARGS));
# set up the paths
StateSpaceAnalysis.setup_path(S)
# load and format the data; split for cross-validation
S = deepcopy(StateSpaceAnalysis.load_data(S));
# build the input tenors (e.g., z-score and convolve with basis)
S = deepcopy(StateSpaceAnalysis.build_inputs(S));
# transform the observed data (PCA)
S = deepcopy(StateSpaceAnalysis.project(S));
# fit baseline models to the data
StateSpaceAnalysis.null_loglik!(S);
# initialize the expectations and parameters
@reset S.est = deepcopy(set_estimates(S));
@reset S = deepcopy(generate_rand_params(S));
These preprocessing steps depend on custom code that you can modify.
if S.prm.ssid_fit == "fit" # if fitting the SSID
@reset S = StateSpaceAnalysis.launch_SSID(S);
elseif S.prm.ssid_fit == "load" # if loading a previously-fit SSID
@reset S = StateSpaceAnalysis.load_SSID(S);
end
@reset S = StateSpaceAnalysis.launch_EM(S);
The basic structure of the EM script:
for em_iter = 1:S.prm.max_iter_em
# ==== E-STEP ================================================================
@inline StateSpaceAnalysis.ESTEP!(S); # estimate the sufficient statistics
# ==== M-STEP ================================================================
@reset S.mdl = deepcopy(StateSpaceAnalysis.MSTEP(S)); # use the sufficient statistics to update the parameters
# ==== TOTAL LOGLIK ==========================================================
StateSpaceAnalysis.total_loglik!(S) # compute the total likelihood
# [...] quality & convergence checks
end
StateSpaceAnalysis.save_results(S)
This needs to be set by the user for the project-specific parameters
assign_arguments
: Assigns command-line arguments to the structure.select_trials
: Selects trials based on custom criteria.scale_input
: Scales the input data.create_input_basis
: Formats inputs with basis functions.transform_observations
: Transforms observations, typically using PCA.format_B_preSSID
: Formats the B matrix for SSID.format_B_postSSID
: Assigns the estimated B columns to the rest of the matrix.
preprocess_fit
: Preprocesses the data and sets up the fitting environment.launch_SSID
: Launches the SSID fitting process.launch_EM
: Launches the EM fitting process.load_SSID
: Loads a previously saved SSID model.save_SSID
: Saves the SSID model.save_results
: Saves the fitting results.
These function are modifed from ControlSystemsIdentification.jl
fit_SSID
: Performs subspace identification for state space analysis.subspaceid_SSA
: modified ControlSystemsIdentification.jl for SSID
fit_EM
: Runs the EM algorithm for individual participants.ESTEP!
: Executes the E-step of the EM algorithm.MSTEP
: Executes the M-step of the EM algorithm.estimate_cov!
: Estimates the latent covariance.estimate_mean!
: Estimates the latent mean.estimate_moments!
: update the sufficient statistics.
posterior_all
: Generates all posterior estimates (mean and covariance).posterior_mean
: Generates only the posterior means.posterior_sse
: Computes the sum of squared errors for the posteriors.
read_args
: process command line arguments (for running on the cluster)setup_path
: Sets up the directory paths for saving results.load_data
: Loads the data from files.build_inputs
: Builds the input matrices for the model.project
: projects the observations (PCA).
gen_rand_params
: generate random SSM parametersgenerate_ssm_trials
: simulate trials from a set of SSM parameters
param_struct
: Defines the parameters structure.data_struct
: Defines the data structure.results_struct
: Defines the results structure.estimates_struct
: Defines the estimates structure.model_struct
: Defines the model structure.core_struct
: Combines all the structures into a core structure.post_all
: Defines the structure for all posterior estimates.post_mean
: Defines the structure for posterior means.post_sse
: Defines the structure for posterior sum of squared errors.
tol_PD
: Ensures a matrix is positive definite with a tolerance.tol_PSD
: Ensures a matrix is positive semi-definite with a tolerance.demix
: Demixes the observations using the saved PCA transformation.remix
: Remixes the observations using the saved PCA transformation.