From 6f05a87647a05ab71c1602dffbb6d28d89dcd57e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 24 Aug 2022 17:37:16 +0100 Subject: [PATCH] Fix for bug in `forward` of `Stacked` (#192) * fixed a bug with the default forward for stacked * bump patch version * Update Project.toml Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 4 ++-- src/bijectors/stacked.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 7202822b..938d4735 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.10.5" +version = "0.10.6" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -35,4 +35,4 @@ MappedArrays = "0.2.2, 0.3, 0.4" Reexport = "0.2, 1" Requires = "0.5, 1" Roots = "1.3.4, 2" -julia = "1.3" +julia = "1.6" diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 7f5272b9..83605188 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -164,10 +164,10 @@ function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]]) logjac = sum(linit) - ys = mapvcat(drop(sb.bs, 1), drop(sb.ranges, 1)) do b, r + ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r y, l = with_logabsdet_jacobian(b, x[r]) logjac += sum(l) y end - return (vcat(yinit, ys), logjac) + return (ys, logjac) end