Skip to content

Commit

Permalink
k-dimensional OU process
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Oct 16, 2023
1 parent 0d786c7 commit 4ba2714
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
27 changes: 18 additions & 9 deletions src/continuous.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
#OU process
struct OrnsteinUhlenbeckDiffusion{T <: Real} <: GaussianStateProcess
# OU process
struct OrnsteinUhlenbeckDiffusion{T} <: GaussianStateProcess
mean::T
volatility::T
reversion::T
end

OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real) = OrnsteinUhlenbeckDiffusion(float.(promote(mean, volatility, reversion))...)
function OrnsteinUhlenbeckDiffusion(mean::Real, volatility::Real, reversion::Real)
μ, σ, θ = float.(promote(mean, volatility, reversion))
return OrnsteinUhlenbeckDiffusion{typeof(μ)}(μ, σ, θ)
end

OrnsteinUhlenbeckDiffusion(mean::T) where T <: Real = OrnsteinUhlenbeckDiffusion(mean,T(1.0),T(0.5))

var(model::OrnsteinUhlenbeckDiffusion) = (model.volatility^2) / (2 * model.reversion)

eq_dist(model::OrnsteinUhlenbeckDiffusion) = Normal(model.mean,sqrt(var(model)))

function forward(process::OrnsteinUhlenbeckDiffusion, x_s::AbstractArray, s::Real, t::Real)
# These are for nested broadcasting
elmwisesqrt(x) = sqrt.(x)
elmwiseinv(x) = inv.(x)
elmwisemul(x, y) = x .* y
elmwisediv(x, y) = x ./ y

function forward(process::OrnsteinUhlenbeckDiffusion{T}, x_s::AbstractArray{T}, s::Real, t::Real) where T
μ, σ, θ = process.mean, process.volatility, process.reversion
mean = @. exp(-(t - s) * θ) * (x_s - μ) + μ
mean = elmwisemul.((exp.(-(t - s) * θ),), x_s .- (μ,)) .+ (μ,)
var = similar(mean)
var .= ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ
fill!(var, @. ((1 - exp(-2(t - s) * θ)) * σ^2) / 2θ)
return GaussianVariables(mean, var)
end

function backward(process::OrnsteinUhlenbeckDiffusion, x_t::AbstractArray, s::Real, t::Real)
function backward(process::OrnsteinUhlenbeckDiffusion{T}, x_t::AbstractArray{T}, s::Real, t::Real) where T
μ, σ, θ = process.mean, process.volatility, process.reversion
mean = @. exp((t - s) * θ) * (x_t - μ) + μ
mean = elmwisemul.((exp.((t - s) * θ),), x_t .- (μ,)) .+ (μ,)
var = similar(mean)
var .= -^2 / 2θ) +^2 * exp(2(t - s) * θ)) / 2θ
fill!(var, @. -^2 / 2θ) +^2 * exp(2(t - s) * θ)) / 2θ)
return= mean, σ² = var)
end

Expand Down
7 changes: 4 additions & 3 deletions src/randomvariable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ end

Base.size(X::GaussianVariables) = size(X.μ)

sample(rng::AbstractRNG, X::GaussianVariables{T}) where T = randn(rng, T, size(X)) .* .√X.σ² .+ X.μ
sample(rng::AbstractRNG, X::GaussianVariables{T}) where T =
elmwisemul.(randn(rng, T, size(X)), elmwisesqrt.(X.σ²)) .+ X.μ

function combine(X::GaussianVariables, lik)
σ² = @. inv(inv(X.σ²) + inv(lik.σ²))
μ = @. σ² * (X.μ / X.σ² + lik.μ / lik.σ²)
σ² = elmwiseinv.(elmwiseinv.(X.σ²) .+ elmwiseinv.(lik.σ²))
μ = elmwisemul.(σ², elmwisediv.(X.μ, X.σ²) .+ elmwisediv.(lik.μ, lik.σ²))
return GaussianVariables(μ, σ²)
end

Expand Down
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ end
x = one(QuatRotation{Float32})
t = 0.29999998f0
@test sampleforward(diffusion, t, [x]) isa Vector

# three-dimensional diffusion
μ = @SVector [0.0, 0.0, 0.0]
θ = @SVector [1.0, 1.0, 1.0]
σ = @SVector [0.5, 0.5, 0.5]
x_0 = [zero(μ), zero(μ)]
diffusion = OrnsteinUhlenbeckDiffusion(μ, σ, θ)
x_t = sampleforward(diffusion, 1.0, x_0)
@test x_t isa typeof(x_0)
@test size(x_t) == size(x_0)
end

@testset "Discrete Diffusions" begin
Expand Down Expand Up @@ -175,6 +185,15 @@ end
x = samplebackward((x, t) -> x + randn(size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t)
@test size(x) == size(x_t)
@test x isa Matrix

μ = @SVector [0.0, 0.0, 0.0]
θ = @SVector [1.0, 1.0, 1.0]
σ = @SVector [0.5, 0.5, 0.5]
x_t = randn(typeof(μ), 4, 10)
process = OrnsteinUhlenbeckDiffusion(μ, σ, θ)
x = samplebackward((x, t) -> x + randn(eltype(x), size(x)), process, [1/8, 1/4, 1/2, 1/1], x_t)
@test size(x) == size(x_t)
@test x isa Matrix
end

@testset "Masked Diffusion" begin
Expand Down

2 comments on commit 4ba2714

@murrellb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check: When talking about a 2D state representation (where the 3rd can be a batch dim - but I'll leave the batch dim implied in the following), there are two different use cases with the same mathematical behaviour here.
Case 1: Where the sample is an image, where neither of the two image dims are privileged over the other. You would typically diffuse pixels all with the same t (although sometimes you might want a different t for each pixel), and if you mask part of the image, the mask needs to be 2D.
Case 2: Where the sample is a vector-valued sequence. Eg. a point in Euc space. So each column of the 2D matrix is a single vector valued observation, and your time is either a scalar, or a row vector, but never a full matrix, and your mask is a row vector.

I haven't looked in detail, but I just want to check that you haven't made case 2 easier to work with, but lost case 1. Also, would it make things easier to just make these different types? OrnsteinUhlenbeckDiffusion and OrnsteinUhlenbeckVectorDiffusion, or something like that?

@bicycle1885
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch is still work-in-progress, and the final change will be different.
My main idea is that we always represent the whole state as a multidimensional array of states, where each element represents a single state.
So if the sample is an image, the array will be a 2d array, where each element is an pixel, and when you batch multiple images, the array will have an additional dimension at the end.
If the sample is a vector-valued sequence, the array will be a 1d array, where each element is a vector to represent a state, and you can batch these arrays with an additional dimension in the same way as images.
Time is either a scalar or a 1d array (vector), the length of which must match the last dimension (which is for batching) of the state array.
In this idea, the masking array will always have the same shape as the state array, which imo simplifies the interface of the package.
Separating types might simplify the code, but I guess it'll result in a bunch of copy-and-pastings of the same code.

Please sign in to comment.