Skip to content

Commit

Permalink
Fix workspace size calculation for fwd conv (#678)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Sep 24, 2024
1 parent 1b2e742 commit 8a23d0d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/dnn/convolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,18 @@ get_workspace_size_func(::Type{miopenConvBwdWeightsAlgorithm_t}) = miopenConvolu
get_workspace_size_func(::Type{miopenConvBwdDataAlgorithm_t}) = miopenConvolutionBackwardDataGetWorkSpaceSize

function get_workspace_size(
conv_type::C; handle, a_desc, b_desc, conv_desc, c_desc,
conv_type::C; handle,
# fwd: x, w, y -> w, x, y
# bwd weight: dy, x, dw -> dy, x, dw
# bwd data: dy, w, dx -> dy, w, dx
a_desc, b_desc, conv_desc::ConvolutionDescriptor, c_desc,
) where C <: CONV_ALGOS
args = conv_type == miopenConvFwdAlgorithm_t ?
# NOTE swap first two args for fwd
(b_desc.handle, a_desc.handle, conv_desc.handle, c_desc.handle) :
(a_desc.handle, b_desc.handle, conv_desc.handle, c_desc.handle)
wsize_ref = Ref{Csize_t}(0)
get_workspace_size_func(conv_type)(
handle, a_desc.handle, b_desc.handle,
conv_desc.handle, c_desc.handle, wsize_ref) # NOTE: do not...
get_workspace_size_func(conv_type)(handle, args..., wsize_ref) # NOTE: do not check
wsize_ref[]
end

Expand Down

0 comments on commit 8a23d0d

Please sign in to comment.