diff --git a/Project.toml b/Project.toml index c047794..3e6cb97 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ SpecialFunctions = "2.4.0" StatsBase = "0.34.3" StatsFuns = "1.3.2" Test = "1.11" -julia = ">=1.11.0" +julia = "1.11.0" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 232f904..dbf0736 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,34 @@ # StateSpaceAnalysis.jl -[![Build Status](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Build Status](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) ## Overview -StateSpaceAnalysis.jl is a Julia package designed for fitting linear Gaussian state space models (lg-SSMs) using Subspace System Identification (SSID) and Expectation Maximization (EM) algorithms. This package provides tools for preprocessing data, fitting models, and evaluating model performance. +*StateSpaceAnalysis.jl* is a Julia package designed for fitting linear-Gaussian state space models (SSMs) using Subspace System Identification (SSID) and Expectation Maximization (EM) algorithms. + +This package provides tools for preprocessing data, fitting models, and evaluating model performance, with methods especially tailored towards neuroimaging analysis: + + +### Event-related designs + +Neuroimaging data often has epoched/batched sequences (e.g., states x timesteps x trials). *StateSpaceAnalysis.jl* handles epoched data by re-using computations across batches, and it includes spline temporal bases for flexible input modeling over the epoch. + + +### High-dimensional Systems + +Whole-brain modelling may require a large number of latent factors. *StateSpaceAnalysis.jl* handles scaling through efficient memory allocation, robust covariance formats (via [*PDMats.jl*](https://github.com/JuliaStats/PDMats.jl)), and regularization. + + +### Data-driven Initialization + +We need good initialization for systems for which we don't have great domain knowledge (especially when there are many latent factors). *StateSpaceAnalysis.jl* handles parameter initialization through subspace identification methods from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl). This version is very preliminary, so there may be some rough edges! ## Installation -To install the StateSpaceAnalysis.jl package, follow these steps: +To install the *StateSpaceAnalysis.jl* package, follow these steps: 1. **Clone the repository:** ```sh @@ -29,15 +46,112 @@ To install the StateSpaceAnalysis.jl package, follow these steps: 3. **Add the package to your Julia environment:** ```julia Pkg.add(path=".") + using StateSpaceAnalysis ``` This will install all the necessary dependencies and set up the StateSpaceAnalysis.jl package for use. +## Walkthrough of the `example/fit_example.jl` script + +### Set up `S`, the core structure which carries the parameters and data structure + +```julia +S = core_struct( + prm=param_struct( + ... # high-level parameters + ), + dat=data_struct( + ... # data and data description + ), + res=results_struct( + ... # fit metrics and model derivates + ), + est=estimates_struct( + ... # scratch space + ), + mdl=model_struct( + ... # estimated model parameters + ), + ); +``` +This structure is used throughout the script, which allows for effective memory management (i.e., the complier can know the size of the data tensors). + +### Preprocess the data: + +```julia +@reset S = StateSpaceAnalysis.preprocess_fit(S); +``` +Preprocessing steps within `preprocess_fit(S)`: +```julia +# read in arguements, helpful for running on a cluster +S = deepcopy(StateSpaceAnalysis.read_args(S, ARGS)); + +# set up the paths +StateSpaceAnalysis.setup_path(S) + +# load and format the data; split for cross-validation +S = deepcopy(StateSpaceAnalysis.load_data(S)); + +# build the input tenors (e.g., z-score and convolve with basis) +S = deepcopy(StateSpaceAnalysis.build_inputs(S)); + +# transform the observed data (PCA) +S = deepcopy(StateSpaceAnalysis.whiten(S)); + +# fit baseline models to the data +StateSpaceAnalysis.null_loglik!(S); + +# initialize the expectations and parameters +@reset S.est = deepcopy(set_estimates(S)); +@reset S = deepcopy(generate_rand_params(S)); +``` + +### Warm-start the EM with initial parameters from Subspace Identification (SSID): + +```julia +if S.prm.ssid_fit == "fit" # if fitting the SSID + @reset S = StateSpaceAnalysis.launch_SSID(S); +elseif S.prm.ssid_fit == "load" # if loading a previously-fit SSID + @reset S = StateSpaceAnalysis.load_SSID(S); +end +``` + +### Fit the parameters use EM: + +```julia +@reset S = StateSpaceAnalysis.launch_EM(S); +``` +The basic structure of the EM script: +```julia +for em_iter = 1:S.prm.max_iter_em + + # ==== E-STEP ================================================================ + @inline StateSpaceAnalysis.ESTEP!(S); # estimate the sufficient statistics + + # ==== M-STEP ================================================================ + @reset S.mdl = deepcopy(StateSpaceAnalysis.MSTEP(S)); # use the sufficient statistics to update the parameters + + # ==== TOTAL LOGLIK ========================================================== + StateSpaceAnalysis.total_loglik!(S) # compute the total likelihood + + # [...] quality & convergence checks +end +``` + +### Save the fit: + +```julia +StateSpaceAnalysis.save_results(S) +``` + + + ## Functions Overview ### `setup/custom.jl` + **This needs to be set by the user for the project-specific parameters** - `assign_arguments`: Assigns command-line arguments to the structure. - `select_trials`: Selects trials based on custom criteria. @@ -48,6 +162,7 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys - `format_B_postSSID`: Assigns the estimated B columns to the rest of the matrix. ### `fit/launch.jl` + - `preprocess_fit`: Preprocesses the data and sets up the fitting environment. - `launch_SSID`: Launches the SSID fitting process. - `launch_EM`: Launches the EM fitting process. @@ -56,11 +171,13 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys - `save_results`: Saves the fitting results. ### `fit/SSID.jl` -**These function are modifed from the excellent ControlSystemsIdentification.jl package** + +**These function are modifed from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl)** - `fit_SSID`: Performs subspace identification for state space analysis. - `subspaceid_SSA`: modified ControlSystemsIdentification.jl for SSID ### `fit/EM.jl` + - `fit_EM`: Runs the EM algorithm for individual participants. - `ESTEP!`: Executes the E-step of the EM algorithm. - `MSTEP`: Executes the M-step of the EM algorithm. @@ -70,19 +187,27 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys ### `fit/posteriors.jl` + - `posterior_all`: Generates all posterior estimates (mean and covariance). - `posterior_mean`: Generates only the posterior means. - `posterior_sse`: Computes the sum of squared errors for the posteriors. ### `setup/setup.jl` + - `read_args`: process command line arguements (for running on the cluster) - `setup_path`: Sets up the directory paths for saving results. - `load_data`: Loads the data from files. - `build_inputs`: Builds the input matrices for the model. - `whiten`: Whitens the observations (PCA). +### `setup/generate.jl` + +- `gen_rand_params`: generate random SSM parameters +- `generate_ssm_trials`: simulate trials from a set of SSM parameters + ### `setup/structs.jl` + - `param_struct`: Defines the parameters structure. - `data_struct`: Defines the data structure. - `results_struct`: Defines the results structure. @@ -94,109 +219,10 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys - `post_sse`: Defines the structure for posterior sum of squared errors. ### `utils/utils.jl` + - `tol_PD`: Ensures a matrix is positive definite with a tolerance. - `tol_PSD`: Ensures a matrix is positive semi-definite with a tolerance. - `demix`: Demixes the observations using the saved PCA transformation. - `remix`: Remixes the observations using the saved PCA transformation. - - - - - - -## Running the Example -To run the example fitting script, follow these steps: - -1. Set the paths in `example/fit_example.jl`: - ```julia - run_cluster = length(ARGS)!=0; - if run_cluster - src_path = "/home/hr0283/HallM_StateSpaceAnalysis/src" - save_path = "/scratch/gpfs/hr0283/HallM_StateSpaceAnalysis/src"; - else - src_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/src" - save_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/example"; - end - ``` - -2. Load the necessary packages and configure the system: - ```julia - using StateSpaceAnalysis - using Accessors - using Random - using LinearAlgebra - using Dates - using Revise - ``` - -3. Set the parameters and data structure: - ```julia - S = core_struct( - prm=param_struct( - seed = rand_seed, - model_name = "test", - changelog = "run test", - load_name = "HallMcMaster2019_ITI100-Cue200-ISI400-Trial200_srate@125_filt@0-30", - load_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/example/example-data", - pt_list = 1:1, - max_iter_em = 500, - ssid_fit = "fit", - ssid_save = false, - ssid_type = :CVA, - ssid_lag = 24, - ), - dat=data_struct( - sel_event = 2:4, - pt = 1, - x_dim = 24, - basis_name = "bspline", - spline_gap = 5, - ), - res=results_struct(), - est=estimates_struct(), - mdl=model_struct(), - ); - ``` - -4. Run the fitting process: - ```julia - @reset S.res.startTime_all = Dates.format(now(), "mm/dd/yyyy HH:MM:SS"); - println("Starting fit at $(S.res.startTime_all)") - - @reset S = StateSpaceAnalysis.preprocess_fit(S); - - if S.prm.ssid_fit == "fit" - @reset S = StateSpaceAnalysis.launch_SSID(S); - elseif S.prm.ssid_fit == "load" - @reset S = StateSpaceAnalysis.load_SSID(S); - end - - @reset S = StateSpaceAnalysis.launch_EM(S); - - @reset S.res.endTime_all = Dates.format(now(), "mm/dd/yyyy HH:MM:SS"); - println("Finished fit at $(Dates.format(now(), "mm/dd/yyyy HH:MM:SS"))") - ``` - -5. Optionally, plot diagnostics and save the fit: - ```julia - do_plots = false - if do_plots - try - StateSpaceAnalysis.plot_loglik_traces(S) - StateSpaceAnalysis.plot_avg_pred(S) - StateSpaceAnalysis.plot_params(S) - catch - end - end - - if S.prm.do_save - println("\n========== SAVING FIT ==========") - StateSpaceAnalysis.save_results(S) - end - ``` - -This will run the example fitting script, performing SSID and EM fitting on the provided data. - - diff --git a/src/StateSpaceAnalysis.jl b/src/StateSpaceAnalysis.jl index ec66584..365cca4 100644 --- a/src/StateSpaceAnalysis.jl +++ b/src/StateSpaceAnalysis.jl @@ -63,7 +63,7 @@ export core_struct, param_struct, data_struct, results_struct, estimates_struct set_model, transform_model include("setup/generate.jl") -export gen_rand_params, generate_ssm_trials +export generate_rand_params, generate_ssm_trials # utility functions diff --git a/src/fit/wrapper.jl b/src/fit/wrapper.jl index 5ca2709..3b8ffcc 100644 --- a/src/fit/wrapper.jl +++ b/src/fit/wrapper.jl @@ -36,7 +36,7 @@ function preprocess_fit(S) @reset S.est = deepcopy(set_estimates(S)); # init model - @reset S = deepcopy(gen_rand_params(S)); + @reset S = deepcopy(generate_rand_params(S)); # ======================================================================= diff --git a/src/setup/generate.jl b/src/setup/generate.jl index d853e72..b51c80a 100644 --- a/src/setup/generate.jl +++ b/src/setup/generate.jl @@ -2,7 +2,7 @@ # init random parameters -function gen_rand_params(S) +function generate_rand_params(S) A = Matrix(Diagonal(rand(S.dat.x_dim))); B = randn(S.dat.x_dim, S.dat.u_dim);