ReverseDiff tracks through length #418

Open
marcoct opened this issue May 19, 2021 · 0 comments

Note that this issue occurs on the branch for #417, not on master. On master this code fails for a different reason, which #417 fixes.

The following code

using Gen

@gen (static) function foo()
    @param b::Vector{Float32}
    n = length(b)
    x = zeros(n)
    a ~ normal(sum(x), 1.0)
    return nothing
end

@load_generated_functions()

init_parameter!((foo, :b), [0.0, 0.0])
trace = simulate(foo, ())
accumulate_param_gradients!(trace)

produces the error:

ERROR: LoadError: MethodError: no method matching zeros(::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
Closest candidates are:
  zeros(::Union{Integer, AbstractUnitRange}...) at array.jl:498
  zeros(::Tuple{Vararg{Union{Integer, AbstractUnitRange}, N} where N}) at array.jl:500
  zeros(::Type{StaticArrays.MVector{N, T} where T}) where N at /home/marcoct/.julia/packages/StaticArrays/xV8rq/src/MVector.jl:25
  ...
Stacktrace:
 [1] (::var"#2#7")(n::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
   @ Main ./none:0
 [2] macro expansion
   @ ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:0 [inlined]
 [3] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270", retval_grad::Nothing, scale_factor::Float64)
   @ Main ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:549
 [4] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270")
   @ Gen ~/.julia/packages/Gen/3mYgc/src/gen_fn_interface.jl:403
 [5] top-level scope
   @ ~/dev/GenExamples.jl/test/test.jl:15
in expression starting at /home/marcoct/dev/GenExamples.jl/test/test.jl:15

As the first stack frame shows, `n` reaches `zeros` as a `ReverseDiff.TrackedReal{Int64}`, for which `zeros` has no method. A careful redesign of how ReverseDiff is used for AD is probably needed. ReverseDiff is currently used as a stop-gap because it provides differentiation of arithmetic and linear algebra operations. Support for AD of new operations should be added by writing generative functions (e.g. using https://www.gen.dev/dev/ref/extending/#Gen.CustomGradientGF) instead of by extending ReverseDiff.
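For illustration only, here is a minimal sketch of what a `CustomGradientGF` wrapper for this kind of computation might look like. The type name `SumOfZerosLike` is hypothetical and not part of Gen. The sketch assumes that `Gen.apply` receives plain (untracked) values, so that `length` returns an ordinary `Int`. It has not been checked against the #417 branch and is not proposed as the fix.

```julia
using Gen

# Hypothetical example type; not part of Gen or of the original report.
struct SumOfZerosLike <: CustomGradientGF{Float64} end

# Forward computation on plain values: `length` is assumed to return an ordinary Int here.
Gen.apply(::SumOfZerosLike, args) = sum(zeros(length(args[1])))

# The output does not depend on the entries of the input vector,
# so the gradient with respect to it is a zero vector of the same shape.
Gen.gradient(::SumOfZerosLike, args, retval, retgrad) = (zero(args[1]),)

# One argument, and a gradient is provided for it.
Gen.has_argument_grads(::SumOfZerosLike) = (true,)
```

The design choice here is that the gradient is written by hand, so the size computation is never differentiated.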
