Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #21 from LuxDL/ap/jet_fails
Browse files Browse the repository at this point in the history
Fix JET Failures
  • Loading branch information
avik-pal authored Jun 21, 2023
2 parents 0fe39c0 + 6c22c50 commit d3274eb
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.2.3"
version = "0.2.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
49 changes: 29 additions & 20 deletions src/api/dropout.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@doc doc"""
dropout(rng::AbstractRNG, x, p, ::Val{training}; dims, invp=inv(p))
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}; dims,
invp=inv(p))
dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims)
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp;
dims)
Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1].
Expand All @@ -15,6 +15,7 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
`dims`. Else, `x` is returned
- `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask`
provided is directly used
- `invp`: Inverse of the probability
## Keyword Arguments
Expand All @@ -32,29 +33,26 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
"""
function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}; dims, invp::T=inv(p)) where {T}
function dropout(rng::AbstractRNG, x::AA, p::T, ::Val{true}, invp::T; dims) where {T}
rng = _replicate(rng)
mask = _generate_dropout_mask(rng, x, p, invp; dims)
return (x .* ignore_derivatives(mask), mask, rng)
end

function dropout(rng::AbstractRNG,
x::AA,
p::T,
::Val{false};
dims,
invp::T=inv(p)) where {T}
return (x, x, rng)
dropout(rng::AbstractRNG, x::AA, p::T, ::Val{false}, ::T; dims) where {T} = (x, x, rng)

function dropout(rng::AbstractRNG, x::AA, p::T, t::Val; dims, invp::T=inv(p)) where {T}
return dropout(rng, x, p, t, invp; dims)
end

function dropout(rng::AbstractRNG,
x::AA,
mask::AA,
p::T,
t::Val,
::Val{true};
dims,
invp::T=inv(p)) where {T}
::Val{true},
invp::T;
dims) where {T}
return dropout(rng, x, p, t; dims, invp)
end

Expand All @@ -63,9 +61,9 @@ function dropout(rng::AbstractRNG,
mask::AA{T2, N},
p::T,
::Val{true},
::Val{false};
dims,
invp::T=inv(p)) where {T, T1, T2, N}
::Val{false},
invp::T;
dims) where {T, T1, T2, N}
size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp)
return x .* ignore_derivatives(mask), mask, rng
end
Expand All @@ -75,10 +73,21 @@ function dropout(rng::AbstractRNG,
mask::AA{T2, N},
p::T,
::Val{false},
::Val{false};
::Val{false},
invp::T;
dims) where {T, T1, T2, N}
return (x, mask, rng)
end

function dropout(rng::AbstractRNG,
x::AA{T1, N},
mask::AA{T2, N},
p::T,
t::Val,
um::Val;
dims,
invp::T=inv(p)) where {T, T1, T2, N}
return (x, mask, rng)
return dropout(rng, x, mask, p, t, um, invp; dims)
end

@doc doc"""
Expand Down Expand Up @@ -139,7 +148,7 @@ alpha_dropout(rng::AbstractRNG, x::AA, p, ::Val{false}, α, A, B) = (x, rng)
return tuple((i in dims ? si : 1 for (i, si) in enumerate(size(s)))...)
end

@inline _dropout_kernel(y, p, invp) = y > p ? invp : oftype(y, 0)
@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0))

@inline _dropout_fptype(x) = float(real(eltype(x)))

Expand Down
10 changes: 5 additions & 5 deletions test/api/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ rng = get_stable_rng(12345)

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)
@jet sum(first(dropout(rng, x, T(0.5), Val(true); dims=Colon())))

@inferred dropout(rng, x, T(0.5), Val(true); dims=Colon())

Expand Down Expand Up @@ -66,7 +66,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)
@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(true); dims=Colon())))

# Try using mask if possible (possible!!)
@inferred dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())
Expand All @@ -90,7 +90,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)
@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

Expand All @@ -116,7 +116,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)
@jet sum(first(dropout(rng, x, mask, T(0.5), Val(true), Val(false); dims=Colon())))

# Testing Mode
@inferred dropout(rng, x, mask, T(0.5), Val(false), Val(false); dims=Colon())
Expand Down Expand Up @@ -151,7 +151,7 @@ end

fp16 = T == Float16
@eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 soft_fail=$fp16 gpu_testing=$on_gpu
@jet __f(x)
@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))

@inferred alpha_dropout(rng, x, T(0.5), Val(false))

Expand Down

2 comments on commit d3274eb

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86070

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.4 -m "<description of version>" d3274eb2f49533668e77fcda0ec833057943a91d
git push origin v0.2.4

Please sign in to comment.