Skip to content

Commit

Permalink
Merge pull request #3 from harrisonritz/dev
Browse files Browse the repository at this point in the history
readme & compat
  • Loading branch information
harrisonritz authored Dec 4, 2024
2 parents 3dfdee9 + f7251bd commit 136c0b9
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 108 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ SpecialFunctions = "2.4.0"
StatsBase = "0.34.3"
StatsFuns = "1.3.2"
Test = "1.11"
julia = ">=1.11.0"
julia = "1.11.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
234 changes: 130 additions & 104 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
# StateSpaceAnalysis.jl

[![Build Status](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Build Status](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/harrisonritz/StateSpaceAnalysis.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)


## Overview

StateSpaceAnalysis.jl is a Julia package designed for fitting linear Gaussian state space models (lg-SSMs) using Subspace System Identification (SSID) and Expectation Maximization (EM) algorithms. This package provides tools for preprocessing data, fitting models, and evaluating model performance.
*StateSpaceAnalysis.jl* is a Julia package designed for fitting linear-Gaussian state space models (SSMs) using Subspace System Identification (SSID) and Expectation Maximization (EM) algorithms.

This package provides tools for preprocessing data, fitting models, and evaluating model performance, with methods especially tailored towards neuroimaging analysis:


### Event-related designs

Neuroimaging data often has epoched/batched sequences (e.g., states x timesteps x trials). *StateSpaceAnalysis.jl* handles epoched data by re-using computations across batches, and it includes spline temporal bases for flexible input modeling over the epoch.


### High-dimensional Systems

Whole-brain modelling may require a large number of latent factors. *StateSpaceAnalysis.jl* handles scaling through efficient memory allocation, robust covariance formats (via [*PDMats.jl*](https://github.com/JuliaStats/PDMats.jl)), and regularization.


### Data-driven Initialization

We need good initialization for systems for which we don't have great domain knowledge (especially when there are many latent factors). *StateSpaceAnalysis.jl* handles parameter initialization through subspace identification methods from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl).

This version is very preliminary, so there may be some rough edges!

## Installation

To install the StateSpaceAnalysis.jl package, follow these steps:
To install the *StateSpaceAnalysis.jl* package, follow these steps:

1. **Clone the repository:**
```sh
Expand All @@ -29,15 +46,112 @@ To install the StateSpaceAnalysis.jl package, follow these steps:
3. **Add the package to your Julia environment:**
```julia
Pkg.add(path=".")
using StateSpaceAnalysis
```

This will install all the necessary dependencies and set up the StateSpaceAnalysis.jl package for use.


## Walkthrough of the `example/fit_example.jl` script

### Set up `S`, the core structure which carries the parameters and data structure

```julia
S = core_struct(
prm=param_struct(
... # high-level parameters
),
dat=data_struct(
... # data and data description
),
res=results_struct(
... # fit metrics and model derivates
),
est=estimates_struct(
... # scratch space
),
mdl=model_struct(
... # estimated model parameters
),
);
```
This structure is used throughout the script, which allows for effective memory management (i.e., the complier can know the size of the data tensors).

### Preprocess the data:

```julia
@reset S = StateSpaceAnalysis.preprocess_fit(S);
```
Preprocessing steps within `preprocess_fit(S)`:
```julia
# read in arguements, helpful for running on a cluster
S = deepcopy(StateSpaceAnalysis.read_args(S, ARGS));
# set up the paths
StateSpaceAnalysis.setup_path(S)
# load and format the data; split for cross-validation
S = deepcopy(StateSpaceAnalysis.load_data(S));
# build the input tenors (e.g., z-score and convolve with basis)
S = deepcopy(StateSpaceAnalysis.build_inputs(S));
# transform the observed data (PCA)
S = deepcopy(StateSpaceAnalysis.whiten(S));
# fit baseline models to the data
StateSpaceAnalysis.null_loglik!(S);
# initialize the expectations and parameters
@reset S.est = deepcopy(set_estimates(S));
@reset S = deepcopy(generate_rand_params(S));
```

### Warm-start the EM with initial parameters from Subspace Identification (SSID):

```julia
if S.prm.ssid_fit == "fit" # if fitting the SSID
@reset S = StateSpaceAnalysis.launch_SSID(S);
elseif S.prm.ssid_fit == "load" # if loading a previously-fit SSID
@reset S = StateSpaceAnalysis.load_SSID(S);
end
```
### Fit the parameters use EM:
```julia
@reset S = StateSpaceAnalysis.launch_EM(S);
```
The basic structure of the EM script:
```julia
for em_iter = 1:S.prm.max_iter_em
# ==== E-STEP ================================================================
@inline StateSpaceAnalysis.ESTEP!(S); # estimate the sufficient statistics
# ==== M-STEP ================================================================
@reset S.mdl = deepcopy(StateSpaceAnalysis.MSTEP(S)); # use the sufficient statistics to update the parameters
# ==== TOTAL LOGLIK ==========================================================
StateSpaceAnalysis.total_loglik!(S) # compute the total likelihood
# [...] quality & convergence checks
end
```
### Save the fit:
```julia
StateSpaceAnalysis.save_results(S)
```
## Functions Overview
### `setup/custom.jl`
**This needs to be set by the user for the project-specific parameters**
- `assign_arguments`: Assigns command-line arguments to the structure.
- `select_trials`: Selects trials based on custom criteria.
Expand All @@ -48,6 +162,7 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys
- `format_B_postSSID`: Assigns the estimated B columns to the rest of the matrix.
### `fit/launch.jl`
- `preprocess_fit`: Preprocesses the data and sets up the fitting environment.
- `launch_SSID`: Launches the SSID fitting process.
- `launch_EM`: Launches the EM fitting process.
Expand All @@ -56,11 +171,13 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys
- `save_results`: Saves the fitting results.
### `fit/SSID.jl`
**These function are modifed from the excellent ControlSystemsIdentification.jl package**
**These function are modifed from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl)**
- `fit_SSID`: Performs subspace identification for state space analysis.
- `subspaceid_SSA`: modified ControlSystemsIdentification.jl for SSID
### `fit/EM.jl`
- `fit_EM`: Runs the EM algorithm for individual participants.
- `ESTEP!`: Executes the E-step of the EM algorithm.
- `MSTEP`: Executes the M-step of the EM algorithm.
Expand All @@ -70,19 +187,27 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys
### `fit/posteriors.jl`
- `posterior_all`: Generates all posterior estimates (mean and covariance).
- `posterior_mean`: Generates only the posterior means.
- `posterior_sse`: Computes the sum of squared errors for the posteriors.
### `setup/setup.jl`
- `read_args`: process command line arguements (for running on the cluster)
- `setup_path`: Sets up the directory paths for saving results.
- `load_data`: Loads the data from files.
- `build_inputs`: Builds the input matrices for the model.
- `whiten`: Whitens the observations (PCA).
### `setup/generate.jl`
- `gen_rand_params`: generate random SSM parameters
- `generate_ssm_trials`: simulate trials from a set of SSM parameters
### `setup/structs.jl`
- `param_struct`: Defines the parameters structure.
- `data_struct`: Defines the data structure.
- `results_struct`: Defines the results structure.
Expand All @@ -94,109 +219,10 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys
- `post_sse`: Defines the structure for posterior sum of squared errors.
### `utils/utils.jl`
- `tol_PD`: Ensures a matrix is positive definite with a tolerance.
- `tol_PSD`: Ensures a matrix is positive semi-definite with a tolerance.
- `demix`: Demixes the observations using the saved PCA transformation.
- `remix`: Remixes the observations using the saved PCA transformation.






## Running the Example
To run the example fitting script, follow these steps:

1. Set the paths in `example/fit_example.jl`:
```julia
run_cluster = length(ARGS)!=0;
if run_cluster
src_path = "/home/hr0283/HallM_StateSpaceAnalysis/src"
save_path = "/scratch/gpfs/hr0283/HallM_StateSpaceAnalysis/src";
else
src_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/src"
save_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/example";
end
```
2. Load the necessary packages and configure the system:
```julia
using StateSpaceAnalysis
using Accessors
using Random
using LinearAlgebra
using Dates
using Revise
```
3. Set the parameters and data structure:
```julia
S = core_struct(
prm=param_struct(
seed = rand_seed,
model_name = "test",
changelog = "run test",
load_name = "HallMcMaster2019_ITI100-Cue200-ISI400-Trial200_srate@125_filt@0-30",
load_path = "/Users/hr0283/Projects/StateSpaceAnalysis.jl/example/example-data",
pt_list = 1:1,
max_iter_em = 500,
ssid_fit = "fit",
ssid_save = false,
ssid_type = :CVA,
ssid_lag = 24,
),
dat=data_struct(
sel_event = 2:4,
pt = 1,
x_dim = 24,
basis_name = "bspline",
spline_gap = 5,
),
res=results_struct(),
est=estimates_struct(),
mdl=model_struct(),
);
```
4. Run the fitting process:
```julia
@reset S.res.startTime_all = Dates.format(now(), "mm/dd/yyyy HH:MM:SS");
println("Starting fit at $(S.res.startTime_all)")
@reset S = StateSpaceAnalysis.preprocess_fit(S);
if S.prm.ssid_fit == "fit"
@reset S = StateSpaceAnalysis.launch_SSID(S);
elseif S.prm.ssid_fit == "load"
@reset S = StateSpaceAnalysis.load_SSID(S);
end
@reset S = StateSpaceAnalysis.launch_EM(S);
@reset S.res.endTime_all = Dates.format(now(), "mm/dd/yyyy HH:MM:SS");
println("Finished fit at $(Dates.format(now(), "mm/dd/yyyy HH:MM:SS"))")
```
5. Optionally, plot diagnostics and save the fit:
```julia
do_plots = false
if do_plots
try
StateSpaceAnalysis.plot_loglik_traces(S)
StateSpaceAnalysis.plot_avg_pred(S)
StateSpaceAnalysis.plot_params(S)
catch
end
end
if S.prm.do_save
println("\n========== SAVING FIT ==========")
StateSpaceAnalysis.save_results(S)
end
```
This will run the example fitting script, performing SSID and EM fitting on the provided data.
2 changes: 1 addition & 1 deletion src/StateSpaceAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export core_struct, param_struct, data_struct, results_struct, estimates_struct
set_model, transform_model

include("setup/generate.jl")
export gen_rand_params, generate_ssm_trials
export generate_rand_params, generate_ssm_trials


# utility functions
Expand Down
2 changes: 1 addition & 1 deletion src/fit/wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function preprocess_fit(S)
@reset S.est = deepcopy(set_estimates(S));

# init model
@reset S = deepcopy(gen_rand_params(S));
@reset S = deepcopy(generate_rand_params(S));
# =======================================================================


Expand Down
2 changes: 1 addition & 1 deletion src/setup/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


# init random parameters
function gen_rand_params(S)
function generate_rand_params(S)

A = Matrix(Diagonal(rand(S.dat.x_dim)));
B = randn(S.dat.x_dim, S.dat.u_dim);
Expand Down

0 comments on commit 136c0b9

Please sign in to comment.