Skip to content

Commit

Permalink
rec loss
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 11, 2024
1 parent 8d946d0 commit b9eab55
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down Expand Up @@ -48,12 +49,13 @@ CUDA = "5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.13"
Functors = "0.5"
EnzymeCore = "0.7.7, 0.8.4"
Functors = "0.5"
MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "1.4.2"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
Expand Down
14 changes: 11 additions & 3 deletions test/ext_common/recurrent_gpu_ad.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
function recurrent_cell_loss(cell, x, state)
out = Flux.scan(cell, x, state)
return mean(out)
out_from_state(state::Tuple) = state[1]
out_from_state(state) = state

function recurrent_cell_loss(cell, seq, state)
out = []
for xt in seq
state = Flux.scan(cell, x, state)
yt = out_from_state(state)
out = vcat(out, [yt])
end
return mean(stack(y, dims = 2))
end

@testset "RNNCell GPU AD" begin
Expand Down

0 comments on commit b9eab55

Please sign in to comment.