Skip to content

Commit

Permalink
Merge pull request #5 from harrisonritz/dev
Browse files Browse the repository at this point in the history
update to 0.2.0
  • Loading branch information
harrisonritz authored Dec 18, 2024
2 parents 136c0b9 + e065484 commit 46c502b
Show file tree
Hide file tree
Showing 17 changed files with 643 additions and 242 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.6'
- 'pre'
- '1.11.0'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/Manifest.toml
Manifest.toml
/example/fit-results
.DS_Store
/*/fit-results
Expand Down
17 changes: 10 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name = "StateSpaceAnalysis"
uuid = "8767b432-2e83-436f-aa62-e8e2db78ab85"
authors = ["Harrison Ritz <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.0"


[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -12,6 +13,7 @@ ControlSystems = "a6e380b2-a6ca-5380-bf3e-84a91bcd477e"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
MATLAB = "10e44e05-a98a-55b3-a45b-ba969058deb6"
Expand All @@ -31,23 +33,24 @@ Aqua = "0.8.9"
BSplines = "0.3.3"
ControlSystemIdentification = "2.11.0"
ControlSystems = "1.10.5"
Dates = "1.11.0"
Dates = "1.10.7"
Distributions = "0.25.113"
FileIO = "1.16.5"
LinearAlgebra = "1.11.0"
FunctionWrappers = "1.1.3"
LinearAlgebra = "1.10.7"
MAT = "0.10.7"
MATLAB = "0.8.4"
MultivariateStats = "0.10.3"
OffsetArrays = "1.14.1"
PDMats = "0.11.31"
Random = "1.11.0"
Random = "1.10.7"
Revise = "3.6.3"
Serialization = "1.11.0"
Serialization = "1.10.7"
SpecialFunctions = "2.4.0"
StatsBase = "0.34.3"
StatsFuns = "1.3.2"
Test = "1.11"
julia = "1.11.0"
Test = "1.10"
julia = "1.10.7"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
76 changes: 48 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,42 @@
This package provides tools for preprocessing data, fitting models, and evaluating model performance, with methods especially tailored towards neuroimaging analysis:


### Event-related designs
### Event-related data

Neuroimaging data often has epoched/batched sequences (e.g., states x timesteps x trials). *StateSpaceAnalysis.jl* handles epoched data by re-using computations across batches, and it includes spline temporal bases for flexible input modeling over the epoch.
Neuroimaging data often has epoched/batched sequences (e.g., states x timesteps x trials). *StateSpaceAnalysis.jl* handles epoched data by reusing computations across batches, and it includes spline temporal bases for flexible input modeling over the epoch.


### High-dimensional Systems
### High-dimensional systems

Whole-brain modelling may require a large number of latent factors. *StateSpaceAnalysis.jl* handles scaling through efficient memory allocation, robust covariance formats (via [*PDMats.jl*](https://github.com/JuliaStats/PDMats.jl)), and regularization.


### Data-driven Initialization
### Data-driven initialization

We need good initialization for systems for which we don't have great domain knowledge (especially when there are many latent factors). *StateSpaceAnalysis.jl* handles parameter initialization through subspace identification methods from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl).
We need good initialization for systems for which we don't have great domain knowledge (especially when there are many latent factors). *StateSpaceAnalysis.jl* handles parameter initialization through subspace identification methods adapted from [*ControlSystemsIdentification.jl*](https://github.com/baggepinnen/ControlSystemIdentification.jl).

This version is very preliminary, so there may be some rough edges!

## Installation

To install the *StateSpaceAnalysis.jl* package, follow these steps:
You can easily install the current release of *StateSpaceAnalysis.jl* from the Julia General Registry:

```julia
using Pkg
Pkg.add("StateSpaceAnalysis")
```


You may want to work directly with the package, e.g., to modify custom functions for setting up your input bases.
You can create a local copy by cloning the github repo:

1. **Clone the repository:**
```sh
git clone https://github.com/harrisonritz/StateSpaceAnalysis.jl.git
cd StateSpaceAnalysis.jl
```
```

2. **Open Julia and activate the package environment:**
2. **Open Julia in the folder and activate the package environment:**
```julia
using Pkg
Pkg.activate(".")
Expand All @@ -49,7 +58,10 @@ To install the *StateSpaceAnalysis.jl* package, follow these steps:
using StateSpaceAnalysis
```

This will install all the necessary dependencies and set up the StateSpaceAnalysis.jl package for use.
This will install all the necessary dependencies and set up the StateSpaceAnalysis.jl package for local use.

Note: You can check which directory you are working in with `pwd()` in Julia. Opening a folder in VS code sets that folder to your path. You can specify the paths in `Pkg.activate("path/to/package")` and `Pkg.add("path/to/package")` even in you aren't in the right folder.
## Walkthrough of the `example/fit_example.jl` script
Expand All @@ -58,22 +70,25 @@ This will install all the necessary dependencies and set up the StateSpaceAnalys
```julia
S = core_struct(
prm=param_struct(
... # high-level parameters
),
dat=data_struct(
... # data and data description
),
res=results_struct(
... # fit metrics and model derivates
),
est=estimates_struct(
... # scratch space
),
mdl=model_struct(
... # estimated model parameters
prm=param_struct(
... # high-level parameters
),
dat=data_struct(
... # data and data description
),
);
res=results_struct(
... # fit metrics and model derivates
),
est=estimates_struct(
... # scratch space
),
mdl=model_struct(
... # estimated model parameters
),
fnc=function_struct{core_struct}(
... # custom functions for setup
)
);
```
This structure is used throughout the script, which allows for effective memory management (i.e., the complier can know the size of the data tensors).
Expand All @@ -82,9 +97,10 @@ This structure is used throughout the script, which allows for effective memory
```julia
@reset S = StateSpaceAnalysis.preprocess_fit(S);
```
Preprocessing steps within `preprocess_fit(S)`:
```julia
# read in arguements, helpful for running on a cluster
# read in arguments, helpful for running on a cluster
S = deepcopy(StateSpaceAnalysis.read_args(S, ARGS));
# set up the paths
Expand All @@ -97,7 +113,7 @@ S = deepcopy(StateSpaceAnalysis.load_data(S));
S = deepcopy(StateSpaceAnalysis.build_inputs(S));
# transform the observed data (PCA)
S = deepcopy(StateSpaceAnalysis.whiten(S));
S = deepcopy(StateSpaceAnalysis.project(S));
# fit baseline models to the data
StateSpaceAnalysis.null_loglik!(S);
Expand All @@ -107,6 +123,10 @@ StateSpaceAnalysis.null_loglik!(S);
@reset S = deepcopy(generate_rand_params(S));
```
These preprocessing steps depend on custom code that you can modify.
# TODO!
### Warm-start the EM with initial parameters from Subspace Identification (SSID):
```julia
Expand Down Expand Up @@ -195,11 +215,11 @@ StateSpaceAnalysis.save_results(S)
### `setup/setup.jl`
- `read_args`: process command line arguements (for running on the cluster)
- `read_args`: process command line arguments (for running on the cluster)
- `setup_path`: Sets up the directory paths for saving results.
- `load_data`: Loads the data from files.
- `build_inputs`: Builds the input matrices for the model.
- `whiten`: Whitens the observations (PCA).
- `project`: projects the observations (PCA).
### `setup/generate.jl`
Expand Down
4 changes: 4 additions & 0 deletions example/fit_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using Random
using LinearAlgebra
using Dates
using Revise # this is for development
import FunctionWrappers: FunctionWrapper

# =============================================================


Expand Down Expand Up @@ -110,6 +112,8 @@ S = core_struct(

mdl=model_struct(),

fcn=function_struct{core_struct}(),

);
println("--- changelog: ",S.prm.changelog, " ---\n\n")
# =======================================================================
Expand Down
40 changes: 24 additions & 16 deletions src/StateSpaceAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ using ControlSystems
using ControlSystemIdentification
using BSplines
using OffsetArrays
using FunctionWrappers
import FunctionWrappers: FunctionWrapper







try
Expand All @@ -28,48 +35,49 @@ end



# custom functions
include("setup/custom.jl")
export assign_arguements, select_trials, scale_input, create_input_basis, launch_EM, load_SSID

include(joinpath("setup","structs.jl"))
export core_struct, param_struct, data_struct, results_struct, estimates_struct, set_estimates, model_struct,
set_model, transform_model, function_struct


# custom functions
include(joinpath("setup","custom.jl"))
# export assign_arguments, select_trials, scale_input, create_input_basis, launch_EM, load_SSID

# fit functions
include("fit/wrapper.jl")
include(joinpath("fit","wrapper.jl"))
export preprocess_fit, launch_SSID, launch_EM, load_SSID, save_SSID, save_results

include("fit/EM.jl")
include(joinpath("fit","EM.jl"))
export fit_EM, ESTEP!, MSTEP, estimate_cov!, filter_cov!, filter_cov_KF!, smooth_cov!,
estimate_mean!, filter_mean!, filter_mean_KF!, smooth_mean!, init_moments!, estimate_moments!,
total_loglik!, total_loglik, test_loglik!, test_loglik, test_orig_loglik, null_loglik!

include("fit/SSID.jl")
include(joinpath("fit","SSID.jl"))

include("fit/likelihoods.jl")
include(joinpath("fit","likelihoods.jl"))
export ll_R2, log_post_v0, log_post, init_lik, dyn_lik, obs_lik,
total_loglik!, total_loglik,
test_loglik!, test_loglik,
test_orig_loglik, null_loglik!

include("fit/posteriors.jl")
include(joinpath("fit","posteriors.jl"))
export posterior_all, posterior_mean, posterior_sse


# setup functions
include("setup/setup.jl")
export read_args, setup_path, load_data, build_inputs, whiten, save_results
include(joinpath("setup","setup.jl"))
export read_args, setup_path, load_data, build_inputs, project, save_results

include("setup/structs.jl")
export core_struct, param_struct, data_struct, results_struct, estimates_struct, set_estimates, model_struct,
set_model, transform_model

include("setup/generate.jl")
include(joinpath("setup","generate.jl"))
export generate_rand_params, generate_ssm_trials


# utility functions
include("utils/utils.jl")
include(joinpath("utils","utils.jl"))
export zsel, zsel_tall, zdim, init_PD, tol_PD, init_PSD, tol_PSD, diag_PD, format_noise, sumsqr, split_list, demix, remix
export format_B_preSSID, format_B_postSSID
export report_R2, report_params


Expand Down
Loading

0 comments on commit 46c502b

Please sign in to comment.