From 6c22c50b676679d08c4989bd1952b1982faee78d Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 21 Jun 2023 17:18:12 -0400
Subject: [PATCH] Fix JET Failures

---
 Project.toml        |  2 +-
 src/api/dropout.jl  | 49 +++++++++++++++++++++++++++------------------
 test/api/dropout.jl | 10 ++++-----
 3 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/Project.toml b/Project.toml
index 23fbaacd..d4c272e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "0.2.3"
+version = "0.2.4"

 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/api/dropout.jl b/src/api/dropout.jl
index cd741865..5407c0e8 100644
--- a/src/api/dropout.jl
+++ b/src/api/dropout.jl
@@ -1,7 +1,7 @@
 @doc doc"""
-    dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p))
-    dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims,
-        invp=inv(p))
+    dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims)
+    dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp;
+        dims)

 Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1].

@@ -15,6 +15,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
     `dims`. Else, `x` is returned
   - `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask`
     provided is directly used
+  - `invp`: Inverse of the probability

 ## Keyword Arguments

@@ -32,19 +33,16 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
 [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
 overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
""" -function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p)) where {T} +function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}, invp::T; dims) where {T} rng = _replicate(rng) mask = _generate_dropout_mask(rng, x, p, invp; dims) return (x .* ignore_derivatives(mask), mask, rng) end -function dropout(rng::AbstractRNG, - x::AA, - p::T, - ::Val{false}; - dims, - invp::T=inv(p)) where {T} - return (x, x, rng) +dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}, ::T; dims) where {T} = (x, x, rng) + +function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) where {T} + return dropout(rng, x, p, t, invp; dims) end function dropout(rng::AbstractRNG, @@ -52,9 +50,9 @@ function dropout(rng::AbstractRNG, mask::AA, p::T, t::Val, - ::Val{true}; - dims, - invp::T=inv(p)) where {T} + ::Val{true}, + invp::T; + dims) where {T} return dropout(rng, x, p, t; dims, invp) end @@ -63,9 +61,9 @@ function dropout(rng::AbstractRNG, mask::AA{T2, N}, p::T, ::Val{true}, - ::Val{false}; - dims, - invp::T=inv(p)) where {T, T1, T2, N} + ::Val{false}, + invp::T; + dims) where {T, T1, T2, N} size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp) return x .* ignore_derivatives(mask), mask, rng end @@ -75,10 +73,21 @@ function dropout(rng::AbstractRNG, mask::AA{T2, N}, p::T, ::Val{false}, - ::Val{false}; + ::Val{false}, + invp::T; + dims) where {T, T1, T2, N} + return (x, mask, rng) +end + +function dropout(rng::AbstractRNG, + x::AA{T1, N}, + mask::AA{T2, N}, + p::T, + t::Val, + um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N} - return (x, mask, rng) + return dropout(rng, x, mask, p, t, um, invp; dims) end @doc doc""" @@ -139,7 +148,7 @@ alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng) return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...) end -@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0) +@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0)) @inline _dropout_fptype(x) = float(real(eltype(x))) diff --git a/test/api/dropout.jl b/test/api/dropout.jl index c941a4c6..2ddcb65c 100644 --- a/test/api/dropout.jl +++ b/test/api/dropout.jl @@ -24,7 +24,7 @@ rng = get_stable_rng(12345) fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon()))) @inferred dropout(rng, x, T(0.5), Val(true); dims=Colon()) @@ -66,7 +66,7 @@ end fp16 = T == Float16 @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu - @jet __f(x) + @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon()))) # Try using mask if possible (possible!!) 
             @inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())

@@ -90,7 +90,7 @@ end

             fp16 = T == Float16
             @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-            @jet __f(x)
+            @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

             mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

@@ -116,7 +116,7 @@ end

             fp16 = T == Float16
             @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-            @jet __f(x)
+            @jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

             # Testing Mode
             @inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon())
@@ -151,7 +151,7 @@ end

             fp16 = T == Float16
             @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-            @jet __f(x)
+            @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))

             @inferred alpha_dropout(rng, x, T(0.5), Val(false))
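For reviewers, a minimal usage sketch of the `dropout` API after this change (not part of the diff; the RNG, array shape, and probability below are illustrative, while the call forms and the `(output, mask, rng)` return convention come from the docstring and methods above):

```julia
using Random, LuxLib

rng = Random.default_rng()
x = randn(Float32, 4, 3)
p = 0.5f0

# Keyword form: public behaviour is unchanged, `invp` still defaults to `inv(p)`
# and the method now simply forwards to the positional variant.
y, mask, rng_kw = dropout(rng, x, p, Val(true); dims=Colon())

# Positional form added in this patch: `invp::T` participates in dispatch,
# which keeps the call concretely typed for JET / `@inferred`.
y2, mask2, rng_pos = dropout(rng, x, p, Val(true), inv(p); dims=Colon())

# Masked form: `Val(false)` for `update_mask` reuses `mask` when its size
# matches `x`; otherwise a fresh mask is generated.
y3, mask3, rng_mask = dropout(rng, x, mask, p, Val(true), Val(false), inv(p); dims=Colon())
```

Because the keyword methods are kept as thin wrappers around the positional ones, code written against the old keyword signatures keeps working, which is consistent with the patch-level version bump from 0.2.3 to 0.2.4.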
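The `_dropout_kernel` tweak near the bottom of src/api/dropout.jl replaces the ternary with `Base.ifelse`. A small sketch of the difference (the `_old`/`_new` names are hypothetical; the bodies mirror the patch):

```julia
# Hypothetical side-by-side of the old and new kernel definitions.
_dropout_kernel_old(y, p, invp) = y > p ? invp : oftype(y, 0)
_dropout_kernel_new(y, p, invp) = ifelse(y > p, invp, oftype(y, 0))

y = rand(Float32, 8)
p, invp = 0.5f0, 2.0f0

# Both definitions produce identical values ...
@assert _dropout_kernel_old.(y, p, invp) == _dropout_kernel_new.(y, p, invp)

# ... but `ifelse` selects between two already-evaluated arms without
# introducing control flow, which is generally friendlier to broadcast
# fusion, GPU kernels, and static analysis.
```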