This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit ba69eef

Merge pull request #35 from LuxDL/ap/forwarddiff
Add Forward Mode rules for conv
avik-pal authored Sep 15, 2023
2 parents c33f7be + 27c104a commit ba69eef
Showing 7 changed files with 146 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.3"
version = "0.3.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
58 changes: 57 additions & 1 deletion ext/LuxLibForwardDiffExt.jl
@@ -1,9 +1,65 @@
module LuxLibForwardDiffExt

using ForwardDiff, LuxLib
import ForwardDiff: Dual

-function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual})
+# dropout
+function LuxLib._dropout_fptype(x::AbstractArray{<:Dual})
    return ForwardDiff.valtype(eltype(x))
end

# Convolutions: We might want to capture these further down in `conv!`
# NOTE: In principle we can concatenate all of the partials along the batch dimension
#       and cut down substantially on the time to compute jacobians.
for op in [:conv, :depthwiseconv]
    op! = Symbol("$(op)!")

    @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, V, P}, N},
            w::AbstractArray{<:Real, N}, cdims::ConvDims; kwargs...) where {N, Tag, V, P}
        x_ = ForwardDiff.value.(x)

        y = $(op)(x_, w, cdims; kwargs...)
        dys = ntuple(i -> $(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P)

        return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
            dys...)
    end

    @eval function NNlib.$(op)(x::AbstractArray{<:Real, N},
            w::AbstractArray{<:Dual{Tag, V, P}, N},
            cdims::ConvDims; kwargs...) where {N, Tag, V, P}
        w_ = ForwardDiff.value.(w)

        y = $(op)(x, w_, cdims; kwargs...)
        dys = ntuple(i -> $(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P)

        return map((yᵢ, dyᵢ...) -> Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
            dys...)
    end

    @eval function NNlib.$(op)(x::AbstractArray{<:Dual{Tag, Vₓ, P}, N},
            w::AbstractArray{<:Dual{Tag, Vₚ, P}, N}, cdims::ConvDims;
            kwargs...) where {N, Tag, Vₓ, Vₚ, P}
        x_ = ForwardDiff.value.(x)
        w_ = ForwardDiff.value.(w)

        y = $(op)(x_, w_, cdims; kwargs...)

        dys₁ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)...,
                NNlib.channels_out(cdims), size(x, N)), P)
        dys₂ = ntuple(_ -> similar(x_, Vₓ, NNlib.output_size(cdims)...,
                NNlib.channels_out(cdims), size(x, N)), P)
        for i in 1:P
            $(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...)
            $(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...)
            dys₁[i] .+= dys₂[i]
        end

        # Technically it should be `promote_type(Vₓ, Vₚ)`, but this causes GPU compilation
        # failure. We will assume it matches the type of the input.
        return map((yᵢ, dyᵢ...) -> Dual{Tag, Vₓ, P}(yᵢ, ForwardDiff.Partials(dyᵢ)), y,
            dys₁...)
    end
end

end
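
To make the effect of these rules concrete, here is a minimal usage sketch (not part of the commit; the sizes, the `Nothing` tag, and the variable names are illustrative assumptions). With LuxLib and ForwardDiff loaded, `conv` accepts Dual-valued inputs and propagates each partial through the convolution:

using ForwardDiff, LuxLib, NNlib

x = randn(Float32, 8, 8, 3, 2)        # WHCN input
w = randn(Float32, 3, 3, 3, 4)        # convolution kernel
u = randn(Float32, size(x))           # direction for a JVP
cdims = DenseConvDims(x, w)

# Single-partial Duals, as a JVP would construct them
xd = ForwardDiff.Dual{Nothing}.(x, u)

yd = conv(xd, w, cdims)               # dispatches to the Dual rule added above
y = ForwardDiff.value.(yd)            # primal output, same as conv(x, w, cdims)
jvp = ForwardDiff.partials.(yd, 1)    # same as conv(u, w, cdims), since conv is linear in x

For the mixed case where both `x` and `w` carry Duals, the third rule above sums the two contributions `conv(∂x, w)` and `conv(x, ∂w)`, i.e. the product rule applied to a map that is linear in each argument.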
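
The NOTE at the top of the file (concatenating all partials along the batch dimension) is left as future work in this commit; a rough sketch of that idea might look like the following, where `conv_partials_batched`, its argument layout, and the slicing are assumptions rather than LuxLib API:

using NNlib

# Hypothetical: compute the primal and all P directional derivatives of `conv`
# with a single convolution by stacking everything along the batch dimension.
function conv_partials_batched(x_, parts::NTuple{P}, w, cdims; kwargs...) where {P}
    N = ndims(x_)                            # batch is the last dimension
    B = size(x_, N)
    xcat = cat(x_, parts...; dims=N)         # batch grows to (P + 1) * B
    ycat = conv(xcat, w, cdims; kwargs...)   # one conv call instead of P + 1
    y = copy(selectdim(ycat, N, 1:B))
    dys = ntuple(i -> copy(selectdim(ycat, N, (i * B + 1):((i + 1) * B))), P)
    return y, dys
end

Whether this actually saves time would depend on how well the backend handles the larger batch versus the extra allocation for `xcat`.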
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
16 changes: 7 additions & 9 deletions test/ext/LuxLibForwardDiffExt.jl
@@ -4,16 +4,14 @@ include("../test_utils.jl")

rng = get_stable_rng(12345)

@testset "dropout" begin
if cpu_testing()
x = randn(rng, Float32, 10, 2)
x_dual = ForwardDiff.Dual.(x)
@testset "$mode: dropout" for (mode, aType, on_gpu) in MODES
x = randn(rng, Float32, 10, 2) |> aType
x_dual = ForwardDiff.Dual.(x)

@test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:)
@test_nowarn dropout(rng, x_dual, 0.5f0, Val(true); dims=:)

x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1]
x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1])
x_dropout = dropout(rng, x, 0.5f0, Val(true); dims=:)[1]
x_dual_dropout = ForwardDiff.value.(dropout(rng, x_dual, 0.5f0, Val(true); dims=:)[1])

@test check_approx(x_dropout, x_dual_dropout)
end
@test check_approx(x_dropout, x_dual_dropout)
end
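
The refactor above runs the same test body once per entry of `MODES` from `test_utils.jl` (not shown in this diff); judging from the destructuring, each entry pairs a mode name, an array constructor, and a GPU flag, hypothetically something like:

# Hypothetical shape, purely for illustration; the real definition lives in test/test_utils.jl.
MODES = [("CPU", Array, false), ("CUDA", CuArray, true)]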
74 changes: 74 additions & 0 deletions test/jvp.jl
@@ -0,0 +1,74 @@
using LuxLib, ForwardDiff, Zygote, Test
using ComponentArrays

include("test_utils.jl")

struct LuxLibTestTag end

# Computes (∂f/∂x)u
function jvp_forwarddiff(f, x, u)
    uu = reshape(u, axes(x))
    y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))), eltype(x),
        1}.(x, ForwardDiff.Partials.(tuple.(uu)))
    return vec(ForwardDiff.partials.(vec(f(y)), 1))
end

function jvp_forwarddiff(f, x::ComponentArray, u)
    xx = getdata(x)
    uu = vec(u)
    y = ComponentArray(ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(),
                eltype(x))), eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))),
        getaxes(x))
    return vec(ForwardDiff.partials.(vec(f(y)), 1))
end

## This exists exclusively for testing. It has horrifying performance implications
function jvp_forwarddiff_concrete(f, x, u)
    Jₓ = ForwardDiff.jacobian(f, x)
    return Jₓ * vec(u)
end

function jvp_zygote(f, x, u)
    Jₓ = only(Zygote.jacobian(f, x))
    return Jₓ * vec(u)
end

function test_jvp_computation(f, x, u)
    jvp₁ = jvp_forwarddiff(f, x, u)
    if !(x isa ComponentArray)
        # ComponentArray + ForwardDiff on GPU don't play nice
        jvp₂ = jvp_forwarddiff_concrete(f, x, u)
        @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
    end

    jvp₃ = jvp_zygote(f, x, u)
    @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
end

@testset "$mode: Jacobian Vector Products" for (mode, aType, on_gpu) in MODES
    @testset "$(op)(; flipped = $flipped)" for flipped in (true, false),
        op in (depthwiseconv, conv)

        op === depthwiseconv && on_gpu && continue

        input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)]
        weight_dims = if op === conv
            [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)]
        else
            [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)]
        end

        @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip(input_dims,
            weight_dims)

            x = randn(Float32, in_dims...) |> aType
            w = randn(Float32, w_dims...) |> aType
            ux = randn(Float32, size(x)...) |> aType
            uw = randn(Float32, size(w)...) |> aType
            u = randn(Float32, length(x) + length(w)) |> aType

            test_jvp_computation(x -> op(x, w; flipped), x, ux)
            test_jvp_computation(w -> op(x, w; flipped), w, uw)
            test_jvp_computation(xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u)
        end
    end
end
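
As a standalone illustration of what `test_jvp_computation` checks (a hypothetical snippet, assuming the helpers above and NNlib's `conv` are in scope; the sizes are arbitrary):

x = randn(Float32, 4, 4, 3, 2)
w = randn(Float32, 3, 3, 3, 4)
ux = randn(Float32, size(x)...)

f = x -> conv(x, w; flipped=true)

jvp_dual = jvp_forwarddiff(f, x, ux)            # Dual-based JVP, exercising the new conv rules
jvp_dense = jvp_forwarddiff_concrete(f, x, ux)  # reference: full Jacobian times vec(ux)
@test jvp_dual ≈ jvp_dense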
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -29,6 +29,10 @@ end
include("ext/LuxLibForwardDiffExt.jl")
end

@time @safetestset "Efficient Jacobian-Vector-Products" begin
include("jvp.jl")
end

if VERSION ≥ v"1.9"
    @time @safetestset "Aqua Tests" begin
        include("aqua.jl")
2 changes: 2 additions & 0 deletions test/test_utils.jl
@@ -2,6 +2,8 @@ using LuxLib, LuxTestUtils, StableRNGs, Test, Zygote
using LuxCUDA
using LuxTestUtils: @jet, @test_gradients, check_approx

CUDA.allowscalar(false)

const GROUP = get(ENV, "GROUP", "All")

cpu_testing() = GROUP == "All" || GROUP == "CPU"
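
The `CUDA.allowscalar(false)` line added here turns accidental scalar indexing of GPU arrays into an error instead of a silent element-by-element fallback, so the new GPU test paths fail loudly if a rule hits a CPU-style loop. A small illustration (assumes a CUDA-capable device):

using CUDA
CUDA.allowscalar(false)

x = CUDA.rand(Float32, 4)
sum(x)      # fine: runs as a GPU reduction
x[1]        # errors: scalar getindex is disallowed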

2 comments on commit ba69eef

@avik-pal
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/91474

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.4 -m "<description of version>" ba69eef81b94fc51c2c859b29ec06646e4553a07
git push origin v0.3.4
