From af1d16894452dae3d50b840dc56d647b1cf95fde Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Sun, 30 Jul 2023 18:57:13 +0100
Subject: [PATCH 1/3] Sketch for batched distribution

---
Notes (review):
    Creates src/batch.jl (83 lines): a `BatchDistributionWrapper` struct, a
    Symbol-based outer constructor, a `@batch` macro, one top-level example
    binding `d`, and `Random.rand` / `Distributions.logpdf` methods (a
    per-element fallback plus `Normal`-specific methods).

    * Struct: `M` and `N` are bound by `where` inside the field types, not as
      struct parameters. NOTE(review): these fields are presumably not
      concretely typed; confirm.
    * Constructor: when any parameter has `ndims == 0` it returns
      `getfield(Distributions, dist)(params...)` directly. That branch runs
      before the `isdefined` assertion, so the assertion does not guard it.
    * `rand(rng, d)` fallback: `map(rand, dists)` does not pass `rng` on.
    * `rand(rng, d, sample_shape)`: draws one extra sample from `dists[1]`
      only to take its `length` as `event_shape`.
    * `Normal` `rand` methods: fill `x` via `rand!(x)` / `CUDA.rand!(x)`, then
      scale by σ and shift by μ; `rng` is unused. NOTE(review): `rand!`
      presumably fills with uniform samples, in which case this is not a
      Normal draw and `randn!` was likely intended; confirm.
    * `@batch`: `args` and `batch_shape` are interpolated without `esc`.
      NOTE(review): they presumably resolve in the macro's module rather
      than in the caller's scope; confirm.
    * `d = @batch ...` is a top-level statement; its parameter arrays are
      2x2x1 while the batch shape passed is (2, 2).
    * The fallback `logpdf` reshapes to `d.batch_shape`; the `Normal` method
      returns the broadcast result without reshaping. NOTE(review): the
      `-0.5` and `log(2π)` literals presumably promote Float32 inputs to
      Float64; confirm.

 src/batch.jl | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)
 create mode 100644 src/batch.jl

diff --git a/src/batch.jl b/src/batch.jl
new file mode 100644
index 00000000..fd25715f
--- /dev/null
+++ b/src/batch.jl
@@ -0,0 +1,83 @@
+using Bijectors
+using CUDA
+using Distributions
+using LinearAlgebra
+using MacroTools
+using Random
+using Functors
+using Flux
+##
+struct BatchDistributionWrapper{D<:Distribution, T<:AbstractArray}
+    distribution::Type{D}
+    parameters::NTuple{M, T} where M
+    batch_shape::NTuple{N, Int} where N
+end
+
+@functor BatchDistributionWrapper (parameters, )
+
+function BatchDistributionWrapper(dist::Symbol, params, batch_shape=())
+    if any((iszero∘ndims), params) # if any of the parameters is a scalar, just return the Distribution
+        return getfield(Distributions, dist)(params...)
+    end
+
+    @assert isdefined(Distributions, dist) "Distribution $dist is not defined"
+
+    @assert all(map(x->eltype(x) != Any, params)) "all parameters should have the concrete element type"
+    @assert all(map(x->eltype(x) == eltype(params[1]), params)) "all parameters should have the same element type"
+
+    D = getfield(Distributions, dist)
+    return BatchDistributionWrapper{D, eltype(params)}(D, params, batch_shape)
+end
+
+macro batch(dist_with_args, batch_shape)
+    c = @capture(dist_with_args, dist_name_(args__))
+    @assert c "$dist_with_args should be a distribution with arguments"
+    quote
+        BatchDistributionWrapper($(Meta.quot(dist_name)), (tuple($(args...))), $(batch_shape))
+    end
+end
+
+d = @batch Normal(zeros(2, 2, 1), ones(2, 2, 1)) (2, 2)
+
+# Default implementation
+function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D, T}
+    dists = D.(d.parameters...)
+    return map(rand, dists)
+end
+
+function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}, sample_shape::Tuple{Vararg{Int, N}}) where {D, T, N}
+    dists = D.(d.parameters...)
+    event_shape = length(rand(rng, dists[1]))
+    samples = map(d -> reshape(rand(rng, d, sample_shape), (sample_shape..., 1)), dists)
+    reshaped_samples = reshape(cat(samples..., dims=N+1), (sample_shape..., d.batch_shape..., event_shape))
+    return reshaped_samples
+end
+
+function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D<:Normal, T<:AbstractArray}
+    μ, σ = d.parameters
+    x = similar(μ)
+    rand!(x)
+    x .*= σ
+    x .+= μ
+    return reshape(x, d.batch_shape)
+end
+
+function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D<:Normal, T<:CuArray}
+    μ, σ = d.parameters
+    x = similar(μ)
+    CUDA.rand!(x)
+    x .*= σ
+    x .+= μ
+    return reshape(x, d.batch_shape)
+end
+
+function Distributions.logpdf(d::BatchDistributionWrapper{D, T}, x::AbstractArray) where {D, T}
+    dists = D.(d.parameters...)
+    return reshape(map(logpdf, dists, x), d.batch_shape)
+end
+
+# both CPU and GPU
+function Distributions.logpdf(d::BatchDistributionWrapper{D, T}, x::AbstractArray) where {D <: Normal, T}
+    μ, σ = d.parameters
+    return -0.5 * (((x .- μ) ./ σ).^2 .+ log(2π) .+ 2 .* log.(σ))
+end
From 4c45f125adb1ee33bfb32cb47e064d936f612e41 Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Sun, 30 Jul 2023 19:02:43 +0100
Subject: [PATCH 2/3] No reshape with `rand`

---
Notes (review):
    Both `Normal`-specific `rand` methods now return `x` as filled instead of
    `reshape(x, d.batch_shape)`. The second hunk also adds a top-level
    `rand(gpu(d))` call plus a blank line; patch 3/3 deletes both again.

 src/batch.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/batch.jl b/src/batch.jl
index fd25715f..d16935d1 100644
--- a/src/batch.jl
+++ b/src/batch.jl
@@ -59,7 +59,7 @@ function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where
     rand!(x)
     x .*= σ
     x .+= μ
-    return reshape(x, d.batch_shape)
+    return x
 end
 
 function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D<:Normal, T<:CuArray}
@@ -68,9 +68,11 @@ function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where
     CUDA.rand!(x)
     x .*= σ
     x .+= μ
-    return reshape(x, d.batch_shape)
+    return x
 end
 
+rand(gpu(d))
+
 function Distributions.logpdf(d::BatchDistributionWrapper{D, T}, x::AbstractArray) where {D, T}
     dists = D.(d.parameters...)
     return reshape(map(logpdf, dists, x), d.batch_shape)
From e5655b40ec81bab00d3464e38aacb823b04d635d Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Sun, 30 Jul 2023 19:09:54 +0100
Subject: [PATCH 3/3] remove extra code

---
Notes (review):
    Deletes the top-level `rand(gpu(d))` call and the blank line that
    followed it; nothing else in the file changes.

 src/batch.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/batch.jl b/src/batch.jl
index d16935d1..cc2aa0d7 100644
--- a/src/batch.jl
+++ b/src/batch.jl
@@ -71,8 +71,6 @@ function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where
     return x
 end
 
-rand(gpu(d))
-
 function Distributions.logpdf(d::BatchDistributionWrapper{D, T}, x::AbstractArray) where {D, T}
     dists = D.(d.parameters...)
     return reshape(map(logpdf, dists, x), d.batch_shape)