Skip to content

Commit

Permalink
test: try fixing enzyme test
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 4, 2024
1 parent 921abf3 commit 7b714c1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ end
)
x = randn(SVector{N, Float64})

fun = let d = d, x = x
ps -> sum(d(x, ps, (;))[1])
grad1 = ForwardDiff.gradient(ComponentArray(ps)) do ps
sumabs2first(d, x, ps, (;))
end
grad1 = ForwardDiff.gradient(fun, ComponentVector(ps))
grad2 = Enzyme.gradient(Enzyme.Reverse, fun, ps)[1]
@test maximum(abs, grad1 .- ComponentVector(grad2)) < 1e-6

grad2 = Enzyme.gradient(
Enzyme.Reverse, sumabs2first, Const(d), Const(x), ps, Const((;))
)[3]

@test maximum(abs, grad1 .- ComponentArray(grad2)) < 1e-6
end
end

Expand Down

0 comments on commit 7b714c1

Please sign in to comment.