This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Fix JET Failures
avik-pal committed Jun 21, 2023
1 parent 0fe39c0 commit 6c22c50
Showing 3 changed files with 35 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
-version = "0.2.3"
+version = "0.2.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
49 changes: 29 additions & 20 deletions src/api/dropout.jl
@@ -1,7 +1,7 @@
@doc doc"""
-dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p))
-dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims,
-invp=inv(p))
+dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims)
+dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp;
+dims)
Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1].
@@ -15,6 +15,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
`dims`. Else, `x` is returned
- `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask`
provided is directly used
+- `invp`: Inverse of the probability
## Keyword Arguments
@@ -32,29 +33,26 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
"""
-function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p)) where {T}
+function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}, invp::T; dims) where {T}
rng = _replicate(rng)
mask = _generate_dropout_mask(rng, x, p, invp; dims)
return (x .* ignore_derivatives(mask), mask, rng)
end

-function dropout(rng::AbstractRNG,
-x::AA,
-p::T,
-::Val{false};
-dims,
-invp::T=inv(p)) where {T}
-return (x, x, rng)
+dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}, ::T; dims) where {T} = (x, x, rng)
+
+function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) where {T}
+return dropout(rng, x, p, t, invp; dims)
end

function dropout(rng::AbstractRNG,
x::AA,
mask::AA,
p::T,
t::Val,
-::Val{true};
-dims,
-invp::T=inv(p)) where {T}
+::Val{true},
+invp::T;
+dims) where {T}
return dropout(rng, x, p, t; dims, invp)
end

@@ -63,9 +61,9 @@ function dropout(rng::AbstractRNG,
mask::AA{T2, N},
p::T,
::Val{true},
-::Val{false};
-dims,
-invp::T=inv(p)) where {T, T1, T2, N}
+::Val{false},
+invp::T;
+dims) where {T, T1, T2, N}
size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp)
return x .* ignore_derivatives(mask), mask, rng
end
@@ -75,10 +73,21 @@ function dropout(rng::AbstractRNG,
mask::AA{T2, N},
p::T,
::Val{false},
-::Val{false};
+::Val{false},
+invp::T;
+dims) where {T, T1, T2, N}
+return (x, mask, rng)
+end
+
+function dropout(rng::AbstractRNG,
+x::AA{T1, N},
+mask::AA{T2, N},
+p::T,
+t::Val,
+um::Val;
dims,
invp::T=inv(p)) where {T, T1, T2, N}
-return (x, mask, rng)
+return dropout(rng, x, mask, p, t, um, invp; dims)
end

@doc doc"""
@@ -139,7 +148,7 @@ alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng)
return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...)
end

-@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0)
+@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0))

@inline _dropout_fptype(x) = float(real(eltype(x)))

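Note on the `_dropout_kernel` change above: `Base.ifelse` evaluates both candidate values and selects one without branching, so the broadcasted kernel stays a single straight-line expression, which inference-based tooling such as JET generally handles more cleanly than the short-circuiting ternary; presumably that is the motivation here. A standalone sketch (not repository code; the sample values are made up):

branchy(y, p, invp) = y > p ? invp : oftype(y, 0)            # old kernel style
branchless(y, p, invp) = ifelse(y > p, invp, oftype(y, 0))   # new kernel style

y = rand(Float32, 4)
p, invp = 0.5f0, 2.0f0

# Both produce the same Float32 mask when broadcast over `y`;
# the ifelse version just avoids a per-element branch.
@assert branchy.(y, p, invp) == branchless.(y, p, invp)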
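For reference, a usage sketch of the reworked `dropout` methods in this file. It assumes `LuxLib` is loaded; the array size, probability, and RNG are illustrative only and not taken from the repository.

using LuxLib, Random

rng = Random.default_rng()
x = randn(Float32, 4, 3)
p = 0.5f0

# Keyword form: kept as a thin wrapper that fills in invp = inv(p) and forwards
# to the positional-invp method.
y, mask, rng2 = dropout(rng, x, p, Val(true); dims=Colon())

# Positional form: invp is passed explicitly, so no keyword default is involved.
y2, mask2, rng3 = dropout(rng, x, p, Val(true), inv(p); dims=Colon())

# Mask-reusing form: with update_mask = Val(false) the provided mask is reused
# as long as its size matches x.
y3, mask3, rng4 = dropout(rng2, x, mask, p, Val(true), Val(false); dims=Colon())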
10 changes: 5 additions & 5 deletions test/api/dropout.jl
@@ -24,7 +24,7 @@ rng = get_stable_rng(12345)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-@jet __f(x)
+@jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon())))

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

@@ -66,7 +66,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-@jet __f(x)
+@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon())))

# Try using mask if possible (possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())
@@ -90,7 +90,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-@jet __f(x)
+@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

@@ -116,7 +116,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-@jet __f(x)
+@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

# Testing Mode
@inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon())
@@ -151,7 +151,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
-@jet __f(x)
+@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

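The test updates above all follow the same pattern: instead of running `@jet` on a locally defined closure `__f`, the macro is applied to the full call expression so the whole chain is analyzed. A minimal sketch of that pattern, assuming the `LuxTestUtils` macros used in this suite and a plain `Xoshiro` RNG standing in for `get_stable_rng(12345)`:

using LuxLib, LuxTestUtils, Test
using Random: Xoshiro

rng = Xoshiro(12345)        # stand-in for get_stable_rng(12345)
x = randn(Float32, 2, 3)

# JET analyzes the whole expression, including first and sum, for dispatch issues.
@jet sum(first(dropout(rng, x, 0.5f0, Val(true); dims=Colon())))
@jet sum(first(alpha_dropout(rng, x, 0.5f0, Val(true))))

# Return-type inference is still checked on the call itself.
@inferred dropout(rng, x, 0.5f0, Val(true); dims=Colon())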
