Skip to content

Commit

Permalink
Merge pull request #387 from vincentmolin/rruleconv
Browse files Browse the repository at this point in the history
add `rrule(::typeof(∇conv_filter)`
  • Loading branch information
ToucheSir authored Feb 12, 2022
2 parents c0b4b8b + ae4866e commit 5fb8d27
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,19 @@ for conv in [:conv, :depthwiseconv]
end
end

function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...)
function ∇conv_filter_pullback(Δ)
Δ1 = colmajor(unthunk(Δ))
return (
NoTangent(),
@thunk(∇conv_data(dy, Δ1, cdims, kw...)),
@thunk(conv(x, Δ1, cdims, kw...)),
NoTangent(),
)
end
return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback
end

# Use NNPACK if it is available and the operation is supported
# commented out 'till proper benchmarking and more correctness test are performed
# if is_nnpack_available()
Expand Down
4 changes: 4 additions & 0 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,10 @@ end
# else
gradtest((y, w) -> sum(∇conv_data(y, w, cdims)), y, w)
# end
gradtest((x, y) -> ∇conv_filter(x, y, cdims), x, y)
if spatial_rank < 3
gradtest((x, y) -> sum(∇conv_filter(x, y, cdims)), x, y)
end

dcdims = DepthwiseConvDims(x, w)
gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w)
Expand Down

0 comments on commit 5fb8d27

Please sign in to comment.