Skip to content

Commit

Permalink
Enzyme: add derivatives for ∇conv_filter and ∇conv_data (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 11, 2024
1 parent d0da256 commit c8cbf76
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
14 changes: 9 additions & 5 deletions ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ using Random

using EnzymeCore.EnzymeRules

for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),
(typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) )
for (name, dataname, filtername) in (
(typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!),
(typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!),
(typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!),
(typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!),
)
@eval begin

function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT},
Expand Down Expand Up @@ -84,11 +88,11 @@ for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!,

if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
# dx += grad wrt x.val
$dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
$dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...)
end
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
# dw += grad wrt w.val
$filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
$filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...)
end

dy .= 0
Expand Down Expand Up @@ -379,4 +383,4 @@ function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropou
end


end
end
46 changes: 45 additions & 1 deletion test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,50 @@ end
end
end

@testset "EnzymeRules: ∇conv_data! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
y = conv(x, w, cdims)

cdims = DenseConvDims(x, w)

curconv = ∇conv_data
curconv! = ∇conv_data!
dst = curconv(y, w, cdims)

for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue

EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (y, Ty), (w, Tw), (cdims, EnzymeCore.Const))
end
end

@testset "EnzymeRules: ∇conv_filter! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
y = conv(x, w, cdims)

cdims = DenseConvDims(x, w)

curconv = ∇conv_filter
curconv! = ∇conv_filter!
dst = curconv(x, w, cdims)

for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Ty) || continue

EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (y, Ty), (cdims, EnzymeCore.Const))
end
end

@testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
Expand All @@ -931,4 +975,4 @@ end
end
end

end
end

0 comments on commit c8cbf76

Please sign in to comment.