From f96bd5810aac32ec79341c0c28463981d91d2d8f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 11 Dec 2024 13:32:30 +0100 Subject: [PATCH] fix --- Project.toml | 2 -- test/ext_common/recurrent_gpu_ad.jl | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 352b33c44e..ade4595b71 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -55,7 +54,6 @@ MLDataDevices = "1.4.2" MLUtils = "0.4" MPI = "0.20.19" MacroTools = "0.5" -Metal = "1.4.2" NCCL = "0.1.1" NNlib = "0.9.22" OneHotArrays = "0.2.4" diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl index af7215537f..3046f1f1fa 100644 --- a/test/ext_common/recurrent_gpu_ad.jl +++ b/test/ext_common/recurrent_gpu_ad.jl @@ -4,11 +4,11 @@ out_from_state(state) = state function recurrent_cell_loss(cell, seq, state) out = [] for xt in seq - state = Flux.scan(cell, x, state) + state = cell(xt, state) yt = out_from_state(state) out = vcat(out, [yt]) end - return mean(stack(y, dims = 2)) + return mean(stack(out, dims = 2)) end @testset "RNNCell GPU AD" begin