diff --git a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl index 7624463da..1894a585b 100644 --- a/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl +++ b/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl @@ -6,8 +6,12 @@ using Random using EnzymeCore.EnzymeRules -for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), - (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!) ) +for (name, dataname, filtername) in ( + (typeof(NNlib.conv!), NNlib.∇conv_data!, NNlib.∇conv_filter!), + (typeof(NNlib.depthwiseconv!), NNlib.∇depthwiseconv_data!, NNlib.∇depthwiseconv_filter!), + (typeof(NNlib.∇conv_data!), NNlib.conv!, NNlib.∇conv_filter!), + (typeof(NNlib.∇conv_filter!), NNlib.∇conv_data!, NNlib.conv!), + ) @eval begin function EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, @@ -84,11 +88,11 @@ for (name, dataname, filtername) in ((typeof(NNlib.conv!), NNlib.∇conv_data!, if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val # dx += grad wrt x.val - $dataname(dx, dy, cache_w, cdims.val; alpha=xT(1), beta=xT(1), kwargs...) + $dataname(dx, $(name != typeof(NNlib.∇conv_filter!) ? :dy : :cache_w), $(name != typeof(NNlib.∇conv_filter!) ? :cache_w : :dy), cdims.val; alpha=xT(1), beta=xT(1), kwargs...) end if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val # dw += grad wrt w.val - $filtername(dw, cache_x, dy, cdims.val; alpha=wT(1), beta=wT(1), kwargs...) + $filtername(dw, $(name != typeof(NNlib.∇conv_data!) ? :cache_x : :dy), $(name != typeof(NNlib.∇conv_data!) ? :dy : :cache_x), cdims.val; alpha=wT(1), beta=wT(1), kwargs...) end dy .= 0 @@ -379,4 +383,4 @@ function EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropou end -end \ No newline at end of file +end diff --git a/test/conv.jl b/test/conv.jl index dc3fc57f5..dce01771a 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -910,6 +910,50 @@ end end end +@testset "EnzymeRules: ∇conv_data! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + y = conv(x, w, cdims) + + cdims = DenseConvDims(x, w) + + curconv = ∇conv_data + curconv! = ∇conv_data! + dst = curconv(y, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (y, Ty), (w, Tw), (cdims, EnzymeCore.Const)) + end +end + +@testset "EnzymeRules: ∇conv_filter! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + y = conv(x, w, cdims) + + cdims = DenseConvDims(x, w) + + curconv = ∇conv_filter + curconv! = ∇conv_filter! + dst = curconv(x, w, cdims) + + for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated), + Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated) + + EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Ty) || continue + + EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (y, Ty), (cdims, EnzymeCore.Const)) + end +end + @testset "EnzymeRules: depthwiseconv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) w = rand(rng, repeat([3], spatial_rank)..., 3, 3) @@ -931,4 +975,4 @@ end end end -end \ No newline at end of file +end