diff --git a/Project.toml b/Project.toml index f19d965b1..e9880b23a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,9 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -19,6 +17,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [extensions] @@ -27,6 +26,7 @@ NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" +NNlibForwardDiffExt = "ForwardDiff" [compat] AMDGPU = "0.9.4, 1" @@ -36,12 +36,11 @@ CUDA = "4, 5" ChainRulesCore = "1.13" EnzymeCore = "0.5, 0.6, 0.7" FFTW = "1.8.0" +ForwardDiff = "0.10.36" GPUArraysCore = "0.1" KernelAbstractions = "0.9.2" LinearAlgebra = "<0.0.1, 1" -Pkg = "<0.0.1, 1" Random = "<0.0.1, 1" -Requires = "1.0" Statistics = "1" cuDNN = "1" julia = "1.9" diff --git a/benchmark/perf_report.jl b/benchmark/perf_report.jl index 5c06515eb..9b861e869 100644 --- a/benchmark/perf_report.jl +++ b/benchmark/perf_report.jl @@ -37,10 +37,6 @@ for rank in (2,), (NNlib.depthwiseconv_im2col!, NNlib.∇depthwiseconv_data_im2col!, NNlib.∇depthwiseconv_filter_im2col!, DepthwiseConvDims, "im2col"), ] - if NNlib.is_nnpack_available() - push!(benchmark_items, (NNlib.conv_nnpack!, NNlib.∇conv_data_nnpack!, NNlib.∇conv_filter_nnpack!, DenseConvDims, "nnpack")) - end - for (conv!, ∇conv_data!, ∇conv_filter!, cT, backend) in benchmark_items x = zeros(Float32, repeat([N], rank)..., C_in, 1) @@ -105,15 +101,4 @@ for rank in (2,), @show(pdims) @save "results.jld2" results end - - if NNlib.is_nnpack_available() - if NNlib.nnpack_supported_operation(pdims) - t_fwd = @benchmark NNlib.maxpool_nnpack!($y, $x, $pdims) - - add_result(t_fwd, "maxpool2d", "nnpack", pdims) - - @show(pdims) - @save "results.jld2" results - end - end end diff --git a/ext/NNlibForwardDiffExt.jl b/ext/NNlibForwardDiffExt.jl new file mode 100644 index 000000000..84351bc17 --- /dev/null +++ b/ext/NNlibForwardDiffExt.jl @@ -0,0 +1,9 @@ +module NNlibForwardDiffExt + +using ForwardDiff: ForwardDiff +using NNlib: NNlib + +NNlib.within_gradient(x::ForwardDiff.Dual) = true +NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true + +end diff --git a/src/NNlib.jl b/src/NNlib.jl index 8cf66370f..687206fca 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -12,9 +12,7 @@ using KernelAbstractions: @atomic using LinearAlgebra using LinearAlgebra.BLAS: @blasfunc, BlasInt using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose -using Pkg using Random -using Requires using Statistics using Statistics: mean @@ -24,19 +22,6 @@ const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} include("dim_helpers.jl") export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims -is_nnpack_available() = false - -@init @require NNPACK_jll="a6bfbf70-4841-5cb9-aa18-3a8ad3c413ee" begin - if isdefined(NNPACK_jll, :libnnpack) - include("nnpack/NNPACK.jl") - else - @warn "NNPACK not available for your platform: " * - "$( Pkg.BinaryPlatforms.platform_name(Pkg.BinaryPlatforms.platform_key_abi()))" * - "($( Pkg.BinaryPlatforms.triplet(Pkg.BinaryPlatforms.platform_key_abi()))) - You will be able to use only the default Julia NNlib backend" - end -end - include("activations.jl") for f in ACTIVATIONS @eval export $(f) @@ -95,11 +80,6 @@ export upsample_nearest, ∇upsample_nearest, include("gather.jl") include("scatter.jl") include("utils.jl") -@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - using .ForwardDiff - within_gradient(x::ForwardDiff.Dual) = true - within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true -end include("sampling.jl") include("functions.jl") diff --git a/src/conv.jl b/src/conv.jl index 3fecb9151..fead2ee21 100644 --- a/src/conv.jl +++ b/src/conv.jl @@ -76,7 +76,7 @@ end # Let's generate auto-allocating versions of all our functions, for all backends. # We `@timeit` these methods separately, as we want to know how much time is spent in # allocation. :P -for backend in (Symbol(), :_direct, :_im2col, :_nnpack) +for backend in (Symbol(), :_direct, :_im2col) # First make auto-allocating versions of the conv()-like calls: for name in (:conv, :depthwiseconv) @eval begin @@ -134,7 +134,7 @@ end # since we can specialize on sizes. for front_name in (:conv, :∇conv_data, :∇conv_filter, :depthwiseconv, :∇depthwiseconv_data, :∇depthwiseconv_filter) - for backend in (Symbol(), :_direct, :_im2col) ## NNPACK is only for 2d conv + for backend in (Symbol(), :_direct, :_im2col) for N in (3, 4) @eval begin function $(Symbol("$(front_name)$(backend)!"))( @@ -381,26 +381,3 @@ function rrule(::typeof(∇conv_filter), x, dy, cdims; kw...) end return ∇conv_filter(x, dy, cdims; kw...), ∇conv_filter_pullback end - -# Use NNPACK if it is available and the operation is supported -# commented out 'till proper benchmarking and more correctness test are performed -# if is_nnpack_available() -# function conv(x::Array{Float32, 4}, w::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return conv_nnpack(x, w, cdims; kwargs...) -# end - -# function ∇conv_data(dy::Array{Float32, 4}, w::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return ∇conv_data_nnpack(dy, w, cdims; kwargs...) -# end - -# function ∇conv_filter(x::Array{Float32, 4}, dy::Array{Float32, 4}, -# cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F}; -# kwargs...) where {K, C_in, C_out, P, F} -# return ∇conv_filter_nnpack(x, dy, cdims; kwargs...) -# end -# end -######################################################## diff --git a/src/nnpack/NNPACK.jl b/src/nnpack/NNPACK.jl deleted file mode 100644 index 685415a7e..000000000 --- a/src/nnpack/NNPACK.jl +++ /dev/null @@ -1,55 +0,0 @@ -using NNPACK_jll - -include("libnnpack_types.jl") -include("error.jl") -include("libnnpack.jl") -include("performance.jl") -include("interface.jl") - - -const shared_threadpool_dict = Dict{UInt64, Base.RefValue}() - -""" - is_nnpack_available() - -Checks if the current hardware is supported by NNPACK. -""" -function is_nnpack_available() - status = nnp_initialize() - if status == nnp_status_unsupported_hardware - return false - else - return true - end -end - -""" - allocate_threadpool() - -Allocates several threadpool based on the upper limit on the number of threads for the machine. -Allows NNPACK to intelligently choose which threadpool to use for getting the best -performance. -""" -function allocate_threadpool() - global NNPACK_CPU_THREADS = NNPACK_CPU_THREADS > 8 ? UInt64(8) : UInt64(exp2(floor(log2(NNPACK_CPU_THREADS)))) - for i in 0:Int(log2(NNPACK_CPU_THREADS)) - threads = UInt64(2^i) - push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads))) - end -end - -@init begin - status = nnp_initialize() - if status == nnp_status_unsupported_hardware - @warn "Hardware is unsupported by NNPACK so falling back to default NNlib" - end - try - global NNPACK_CPU_THREADS = parse(UInt64, ENV["NNPACK_CPU_THREADS"]) - catch - # Sys.CPU_THREADS should be a better default if we are tuning the benchmark suite on - # a particular machine. However, we fix the runtime threadpool here to have a max of - # 4 threads so anything above will be ignored anyways - global NNPACK_CPU_THREADS = UInt64(4) - end - allocate_threadpool() -end diff --git a/src/nnpack/error.jl b/src/nnpack/error.jl deleted file mode 100644 index 83522c37d..000000000 --- a/src/nnpack/error.jl +++ /dev/null @@ -1,83 +0,0 @@ -struct NNPACKError <: Exception - code::nnp_status - msg::AbstractString -end - -Base.show(io::IO, err::NNPACKError) = print(io, "NNPACKError(code $(err.code), $(err.msg))") - -function NNPACKError(status::nnp_status) - msg = "NNPACK STATUS SUCCESS" - if status == nnp_status_invalid_batch_size - msg = "NNPACK STATUS INVALID BATCH SIZE" - elseif status == nnp_status_invalid_channels - msg = "NNPACK STATUS INVALID CHANNELS" - elseif status == nnp_status_invalid_input_channels - msg = "NNPACK STATUS INVALID INPUT CHANNELS" - elseif status == nnp_status_invalid_output_channels - msg = "NNPACK STATUS INVALID OUTPUT CHANNELS" - elseif status == nnp_status_invalid_input_size - msg = "NNPACK STATUS INVALID INPUT SIZE" - elseif status == nnp_status_invalid_input_stride - msg = "NNPACK STATUS INVALID INPUT STRIDE" - elseif status == nnp_status_invalid_input_padding - msg = "NNPACK STATUS INVALID INPUT PADDING" - elseif status == nnp_status_invalid_kernel_size - msg = "NNPACK STATUS INVALID KERNEL SIZE" - elseif status == nnp_status_invalid_pooling_size - msg = "NNPACK STATUS INVALID POOLING SIZE" - elseif status == nnp_status_invalid_pooling_stride - msg = "NNPACK STATUS INVALID POOLING STRIDE" - elseif status == nnp_status_invalid_algorithm - msg = "NNPACK STATUS INVALID ALGORITHM" - elseif status == nnp_status_invalid_transform_strategy - msg = "NNPACK STATUS INVALID TRANSFORM STRATEGY" - elseif status == nnp_status_invalid_output_subsampling - msg = "NNPACK STATUS INVALID OUTPUT SUBSAMPLING" - elseif status == nnp_status_invalid_activation - msg = "NNPACK STATUS INVALID ACTIVATION" - elseif status == nnp_status_invalid_activation_parameters - msg = "NNPACK STATUS INVALID ACTIVATION PARAMETERS" - elseif status == nnp_status_unsupported_input_size - msg = "NNPACK STATUS UNSUPPORTED INPUT SIZE" - elseif status == nnp_status_unsupported_input_stride - msg = "NNPACK STATUS UNSUPPORTED INPUT STRIDE" - elseif status == nnp_status_unsupported_input_padding - msg = "NNPACK STATUS UNSUPPORTED INPUT PADDING" - elseif status == nnp_status_unsupported_kernel_size - msg = "NNPACK STATUS UNSUPPORTED KERNEL SIZE" - elseif status == nnp_status_unsupported_pooling_size - msg = "NNPACK STATUS UNSUPPORTED POOLING SIZE" - elseif status == nnp_status_unsupported_pooling_stride - msg = "NNPACK STATUS UNSUPPORTED POOLING STRIDE" - elseif status == nnp_status_unsupported_algorithm - msg = "NNPACK STATUS UNSUPPORTED ALGORITHM" - elseif status == nnp_status_unsupported_transform_strategy - msg = "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY" - elseif status == nnp_status_unsupported_activation - msg = "NNPACK STATUS UNSUPPORTED ACTIVATION" - elseif status == nnp_status_unsupported_activation_parameters - msg = "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS" - elseif status == nnp_status_uninitialized - msg = "NNPACK STATUS UNINITIALIZED" - elseif status == nnp_status_unsupported_hardware - msg = "NNPACK STATUS UNSUPPORTED HARDWARE" - elseif status == nnp_status_out_of_memory - msg = "NNPACK STATUS OUT OF MEMORY" - elseif status == nnp_status_insufficient_buffer - msg = "NNPACK STATUS INSUFFICIENT BUFFER" - elseif status == nnp_status_misaligned_buffer - msg = "NNPACK STATUS MISALIGNED BUFFER" - end - NNPACKError(status, msg) -end - -macro nnpack_check(nnp_func) - quote - local err::nnp_status - err = $(esc(nnp_func)) - if err != nnp_status_success - throw(NNPACKError(err)) - end - err - end -end diff --git a/src/nnpack/impl.jl b/src/nnpack/impl.jl deleted file mode 100644 index 3309404e1..000000000 --- a/src/nnpack/impl.jl +++ /dev/null @@ -1,50 +0,0 @@ -function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}} - check_dims(size(x), size(y), pdims) - threadpool = select_threadpool(pdims, size(y, 4)) - nnp_max_pooling_output(y, x, kernel_size(pdims), padding = padding(pdims), - stride = stride(pdims), threadpool = threadpool) -end - -function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims; - b::A2 = zeros(Float32, size(x, 3)), - algo = UInt32(0)) where {A1<:Array{Float32, 4}, - A2<:Array{Float32, 1}} - check_dims(size(x), size(w), size(y), cdims) - threadpool = select_threadpool(cdims, size(y, 4)) - - if flipkernel(cdims) == 0 - w = flipweight(w) - end - - nnp_convolution_output(y, x, w, b, algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) -end - -function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims; - algo = UInt32(0)) where{A<:Array{Float32, 4}} - check_dims(size(dx), size(w), size(dy), cdims) - threadpool = select_threadpool(cdims, size(dy, 4)) - - if flipkernel(cdims) == 0 - w = flipweight(w) - end - - nnp_convolution_input_gradient(dx, dy, w, algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) -end - -function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims; - algo = UInt32(0)) where{A<:Array{Float32, 4}} - check_dims(size(x), size(dw), size(dy), cdims) - threadpool = select_threadpool(cdims, size(dy, 4)) - - nnp_convolution_kernel_gradient(dw, x, dy, algo = algo, padding = padding(cdims), - stride = stride(cdims), threadpool = threadpool) - - if flipkernel(cdims) == 0 - dw .= flipweight(dw) - end - - dw -end - diff --git a/src/nnpack/interface.jl b/src/nnpack/interface.jl deleted file mode 100644 index 5cdaccb4d..000000000 --- a/src/nnpack/interface.jl +++ /dev/null @@ -1,44 +0,0 @@ -include("impl.jl") - -## NNPACK supports only Float32 -for (front_name, backend) in ( - :conv => :_nnpack, - :∇conv_data => :_nnpack, - :∇conv_filter => :_nnpack, - ) - @eval begin - function $(Symbol("$(front_name)$(backend)!"))( - out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4}, - cdims::ConvDims; kwargs...) where {T1, T2, T3} - @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1 - # Output must of the same type as in the function signature - T1.($(Symbol("$(front_name)$(backend)!"))(Float32.(out), Float32.(in1), - Float32.(in2), cdims; kwargs...)) - end - end -end - -function maxpool_nnpack!(y::Array{T1, 4}, x::Array{T2, 4}, pdims::PoolDims; - kwargs...) where {T1, T2} - @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1 - # We want the output to be of the same type as desired - T1.(maxpool_nnpack!(Float32.(y), Float32.(x), pdims; kwargs...)) -end - -""" - nnpack_supported_operation(cdims::ConvDims) - nnpack_supported_operation(pdims::PoolDims) - -Returns `true` if nnpack supports the convolution/pooling operation for the given parameters. -""" -function nnpack_supported_operation(pdims::PoolDims{2, K, S, P, (1, 1)}) where {K, S, P} - val = input_size(pdims)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K - return val .% S == (0, 0) ? true : false -end - -function nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)}) where {K, S, P} - return true -end - -# Return false for everything else -nnpack_supported_operation(dims) = false diff --git a/src/nnpack/libnnpack.jl b/src/nnpack/libnnpack.jl deleted file mode 100644 index 2f3996c32..000000000 --- a/src/nnpack/libnnpack.jl +++ /dev/null @@ -1,135 +0,0 @@ -#NOTE: We do the error handling of nnp_initialize while loading NNPACK -function nnp_initialize() - ccall((:nnp_initialize, libnnpack), nnp_status, (),) -end - -function nnp_deinitialize() - @nnpack_check ccall((:nnp_deinitialize, libnnpack), nnp_status, (),) -end - -function pthreadpool_create(n = 0) - ccall((:pthreadpool_create, libnnpack), Ptr{Cvoid}, (Csize_t,), n) -end - -function nnp_relu_output(batch_size, channels, input, output, negative_slope, threadpool) - @nnpack_check ccall((:nnp_relu_output, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, input, output, negative_slope, threadpool) -end - -function nnp_relu_output(x::Array{Float32,N}, y::Array{Float32,N}; negative_slope::AbstractFloat = 0.0, threadpool = C_NULL) where {N} - # Investigate why the channel and batch dims need to specified like this - nnp_relu_output(prod(size(x)[N-1:N]), prod(size(x)[1:N-2]), x, y, negative_slope, threadpool) - y -end - -function nnp_relu_input_gradient(batch_size, channels, grad_output, input, grad_input, negative_slope, threadpool) - @nnpack_check ccall((:nnp_relu_input_gradient, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Cfloat, pthreadpool_t), batch_size, channels, grad_output, input, grad_input, negative_slope, threadpool) -end - -function nnp_relu_input_gradient(x::Array{Float32,N}, dy::Array{Float32,N}, dx::Array{Float32,N}; negative_slope::AbstractFloat = 0.0, threadpool = C_NULL) where {N} - # Investigate why the channel and batch dims need to specified like this - nnp_relu_input_gradient(Csize_t(prod(size(x)[N-1:N])), prod(size(x)[1:N-2]), dy, x, dx, negative_slope, threadpool) - dx -end - -function nnp_softmax_output(batch_size, channels, input, output, threadpool) - @nnpack_check ccall((:nnp_softmax_output, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), batch_size, channels, input, output, threadpool) -end - -function nnp_softmax_output(x::VecOrMat{Float32}, y::VecOrMat{Float32}; threadpool = C_NULL) - nnp_softmax_output(ndims(x) == 2 ? size(x, 2) : 1, size(x, 1), x, y, threadpool) - y -end - -#FIXME: Output of fully connected not consistent with `kernel * input` -#NOTE: This most likely due to nnpack being row major. Investigate this. - -function nnp_fully_connected_output(batch_size, input_channels, output_channels, input, kernel, output, threadpool, profile) - @nnpack_check ccall((:nnp_fully_connected_output, libnnpack), nnp_status, (Csize_t, Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t, Ptr{Cvoid}), batch_size, input_channels, output_channels, input, kernel, output, threadpool, C_NULL) -end - -function nnp_fully_connected_output(x::Array{Float32,2}, w::Array{Float32,2}, y::Array{Float32,2}; profile = nothing, threadpool = C_NULL) - profile = profile == nothing ? nnp_profile() : profile - nnp_fully_connected_output(size(x, 2), size(x, 1), size(w, 1), x, w, y, threadpool, profile) - y -end - -function nnp_fully_connected_inference_f16f32(input_channels, output_channels, input, kernel, output, threadpool) - @nnpack_check ccall((:nnp_fully_connected_inference_f16f32, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cvoid}, Ptr{Cfloat}, pthreadpool_t), input_channels, output_channels, input, kernel, output, threadpool) -end - -nnp_fully_connected_inference_f16f32(x::Array{Float32, 1}, w::Array{Float16,2}, y::Array{Float32, 1}; threadpool = C_NULL) = - nnp_fully_connected_inference(reshape(x, size(x), 1), w, reshape(y, size(y), 1), threadpool = threadpool) - -function nnp_fully_connected_inference_f16f32(x::Array{Float32, 2}, w::Array{Float16,2}, y::Array{Float32, 2}; threadpool = C_NULL) - nnp_fully_connected_inference(size(x, 1), size(y, 1), x, w, y, threadpool) - y -end - -function nnp_fully_connected_inference(input_channels, output_channels, input, kernel, output, threadpool) - @nnpack_check ccall((:nnp_fully_connected_inference, libnnpack), nnp_status, (Csize_t, Csize_t, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), input_channels, output_channels, input, kernel, output, threadpool) -end - -nnp_fully_connected_inference(x::Array{Float32, 1}, w::Array{Float32,2}; threadpool = C_NULL) = - nnp_fully_connected_inference(reshape(x, size(x), 1), w, threadpool = threadpool) - -function nnp_fully_connected_inference(x::Array{Float32, 2}, w::Array{Float32, 2}, y::Array{Float32, 2}; threadpool = C_NULL) - nnp_fully_connected_inference(size(x, 1), size(y, 1), x, w, y, threadpool) - y -end - -function nnp_max_pooling_output(batch_size, channels, input_size, input_padding, pooling_size, pooling_stride, input, output, threadpool) - @nnpack_check ccall((:nnp_max_pooling_output, libnnpack), nnp_status, (Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), batch_size, channels, input_size, input_padding, pooling_size, pooling_stride, input, output, threadpool) -end - -function nnp_max_pooling_output(y::Array{Float32,4}, x::Array{Float32,4}, kernel::Tuple; padding = 0, stride = 1, threadpool = C_NULL) - input_size = nnp_size(Csize_t.((size(x, 1), size(x, 2)))...) - pooling_size = nnp_size(Csize_t.(kernel)...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - pooling_stride = nnp_size(Csize_t.(stride)...) - nnp_max_pooling_output(size(x, 4), size(x, 3), input_size, input_padding, pooling_size, pooling_stride, x, y, threadpool) - y -end - -#TODO: Add wrapper for convolution inference - -function nnp_convolution_input_gradient(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, grad_output, kernel, grad_input, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_input_gradient, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, grad_output, kernel, grad_input, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_input_gradient(dx::Array{Float32,4}, dy::Array{Float32,4}, w::Array{Float32,4}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(dx,1), size(dx,2)))...) - kernel_size = nnp_size(Csize_t.((size(w,1),size(w,2)))...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - profile = profile == nothing ? nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_input_gradient(UInt32(algo), size(dx,4), size(dx,3), size(w,4), input_size, input_padding, kernel_size, dy, w, dx, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - dx -end - -function nnp_convolution_kernel_gradient(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, grad_output, grad_kernel, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_kernel_gradient, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, grad_output, grad_kernel, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_kernel_gradient(dw::Array{Float32,4}, x::Array{Float32,4}, dy::Array{Float32,4}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(x,1), size(x,2)))...) - kernel_size = nnp_size(Csize_t.((size(dw,1),size(dw,2)))...) - input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1])) - profile = profile == nothing ? nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_kernel_gradient(UInt32(algo), size(x,4), size(x,3), size(dw,4), input_size, input_padding, kernel_size, x, dy, dw, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - dw -end - -function nnp_convolution_output(algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, kernel, bias, output, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, profile) - @nnpack_check ccall((:nnp_convolution_output, libnnpack), nnp_status, (nnp_convolution_algorithm, Csize_t, Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cfloat}, Ptr{Cvoid}, Csize_t, nnp_activation, Ptr{Cvoid}, pthreadpool_t, Ptr{Cvoid}), algorithm, batch_size, input_channels, output_channels, input_size, input_padding, kernel_size, input, kernel, bias, output, workspace_buffer, workspace_size, activation, activation_parameters, threadpool, C_NULL) -end - -function nnp_convolution_output(y::Array{Float32,4}, x::Array{Float32,4}, w::Array{Float32,4}, b::Array{Float32,1}; algo::nnp_convolution_algorithm = UInt32(0), workspace_buffer = nothing, workspace_size = 0, padding = 0, stride = 1, threadpool = C_NULL, profile = nothing) - input_size = nnp_size(Csize_t.((size(x,1), size(x,2)))...) - kernel_size = nnp_size(Csize_t.((size(w,1),size(w,2)))...) - input_padding = nnp_padding(Csize_t(padding[3]), Csize_t(padding[2]), Csize_t(padding[4]), Csize_t(padding[1])) - profile = profile == nothing ? nnp_profile() : profile - workspace_buffer = workspace_buffer === nothing ? C_NULL : workspace_buffer - nnp_convolution_output(UInt32(algo), size(x,4), size(x,3), size(w,4), input_size, input_padding, kernel_size, x, w, b, y, workspace_buffer, workspace_size, UInt32(0), C_NULL, threadpool, profile) - y -end diff --git a/src/nnpack/libnnpack_types.jl b/src/nnpack/libnnpack_types.jl deleted file mode 100644 index 6e7b23c16..000000000 --- a/src/nnpack/libnnpack_types.jl +++ /dev/null @@ -1,85 +0,0 @@ -const nnp_status = UInt32 -const nnp_status_success = (UInt32)(0) -const nnp_status_invalid_batch_size = (UInt32)(2) -const nnp_status_invalid_channels = (UInt32)(3) -const nnp_status_invalid_input_channels = (UInt32)(4) -const nnp_status_invalid_output_channels = (UInt32)(5) -const nnp_status_invalid_input_size = (UInt32)(10) -const nnp_status_invalid_input_stride = (UInt32)(11) -const nnp_status_invalid_input_padding = (UInt32)(12) -const nnp_status_invalid_kernel_size = (UInt32)(13) -const nnp_status_invalid_pooling_size = (UInt32)(14) -const nnp_status_invalid_pooling_stride = (UInt32)(15) -const nnp_status_invalid_algorithm = (UInt32)(16) -const nnp_status_invalid_transform_strategy = (UInt32)(17) -const nnp_status_invalid_output_subsampling = (UInt32)(13) -const nnp_status_invalid_activation = (UInt32)(14) -const nnp_status_invalid_activation_parameters = (UInt32)(15) -const nnp_status_unsupported_input_size = (UInt32)(20) -const nnp_status_unsupported_input_stride = (UInt32)(21) -const nnp_status_unsupported_input_padding = (UInt32)(22) -const nnp_status_unsupported_kernel_size = (UInt32)(23) -const nnp_status_unsupported_pooling_size = (UInt32)(24) -const nnp_status_unsupported_pooling_stride = (UInt32)(25) -const nnp_status_unsupported_algorithm = (UInt32)(26) -const nnp_status_unsupported_transform_strategy = (UInt32)(57) -const nnp_status_unsupported_activation = (UInt32)(28) -const nnp_status_unsupported_activation_parameters = (UInt32)(29) -const nnp_status_uninitialized = (UInt32)(50) -const nnp_status_unsupported_hardware = (UInt32)(51) -const nnp_status_out_of_memory = (UInt32)(52) -const nnp_status_insufficient_buffer = (UInt32)(53) -const nnp_status_misaligned_buffer = (UInt32)(54) - -const nnp_activation = UInt32 -const nnp_activation_identity = (UInt32)(0) -const nnp_activation_relu = (UInt32)(1) - -const nnp_convolution_algorithm = UInt32 -const nnp_convolution_algorithm_auto = (UInt32)(0) -const nnp_convolution_algorithm_ft8x8 = (UInt32)(1) -const nnp_convolution_algorithm_ft16x16 = (UInt32)(2) -const nnp_convolution_algorithm_wt8x8 = (UInt32)(3) -const nnp_convolution_algorithm_implicit_gemm = (UInt32)(4) -const nnp_convolution_algorithm_direct = (UInt32)(5) -const nnp_convolution_algorithm_wt8x8_fp16 = (UInt32)(6) - -const nnp_convolution_transform_strategy = UInt32 -const nnp_convolution_transform_strategy_compute = (UInt32)(1) -const nnp_convolution_transform_strategy_precompute = (UInt32)(2) -const nnp_convolution_transform_strategy_reuse = (UInt32)(3) - -const pthreadpool_t = Ptr{Nothing} - -mutable struct nnp_size - width::Csize_t - height::Csize_t - nnp_size() = new(Csize_t(0), Csize_t(0)) - nnp_size(w, h) = new(Csize_t(w), Csize_t(h)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_size}}, a::nnp_size) = Ptr{a} - -mutable struct nnp_padding - top::Csize_t - right::Csize_t - bottom::Csize_t - left::Csize_t - nnp_padding() = new(Csize_t(0), Csize_t(0), Csize_t(0), Csize_t(0)) - nnp_padding(val) = new(Csize_t(val), Csize_t(val), Csize_t(val), Csize_t(val)) - nnp_padding(t, r, b, l) = new(Csize_t(t), Csize_t(r), Csize_t(b), Csize_t(l)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_padding}}, a::nnp_padding) = Ptr{a} - -mutable struct nnp_profile - total::Cdouble - input_transform::Cdouble - kernel_transform::Cdouble - output_transform::Cdouble - block_multiplication::Cdouble - nnp_profile() = new(Cdouble(0.0), Cdouble(0.0), Cdouble(0.0), Cdouble(0.0), Cdouble(0.0)) - nnp_profile(t, it, kt, ot, bm) = new(Cdouble(t), Cdouble(it), Cdouble(kt), Cdouble(ot), Cdouble(bm)) -end - -Base.unsafe_convert(::Type{Ptr{nnp_profile}}, a::nnp_profile) = Ptr{a} diff --git a/src/nnpack/performance.jl b/src/nnpack/performance.jl deleted file mode 100644 index 24abdb411..000000000 --- a/src/nnpack/performance.jl +++ /dev/null @@ -1,31 +0,0 @@ -function select_threadpool(cdims::DenseConvDims, batch_size::Int) - inp_size = input_size(cdims)[1] - if batch_size >= 32 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif batch_size >= 16 && inp_size >= 64 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size <= 32 - return C_NULL - elseif inp_size >= 128 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size * batch_size >= 256 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - end - return C_NULL -end - -function select_threadpool(pdims::PoolDims, batch_size::Int) - inp_size = input_size(pdims)[1] - if batch_size >= 32 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif batch_size >= 16 && inp_size >= 64 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size <= 32 - return C_NULL - elseif inp_size >= 128 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - elseif inp_size * batch_size >= 256 - return shared_threadpool_dict[Int(NNPACK_CPU_THREADS)][] - end - return C_NULL -end diff --git a/src/pooling.jl b/src/pooling.jl index 1cf666f54..6caa1045d 100644 --- a/src/pooling.jl +++ b/src/pooling.jl @@ -107,7 +107,7 @@ end # Finally, let's generate auto-allocating versions of all our functions, for all backends: -for backend in (Symbol(), :_direct, :_nnpack) +for backend in (Symbol(), :_direct) # First make auto-allocating versions of the basic pooling calls: for name in (:maxpool, :meanpool, :lpnormpool) @eval begin @@ -132,16 +132,6 @@ for backend in (Symbol(), :_direct, :_nnpack) end end -## Use NNPACK if it is available and operation is supported. -## The corresponding gradient is not available in NNPACK -## Commented out due to #210 -# if is_nnpack_available() -# function maxpool(x::Array{Float32, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}; kwargs...) where {T, K, S, P} -# func = nnpack_supported_operation(pdims) ? maxpool_nnpack : maxpool_direct -# return func(x, pdims; kwargs...) -# end -# end - expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) diff --git a/test/conv.jl b/test/conv.jl index 492de2cc7..badb2c5f0 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -277,13 +277,7 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) w = reshape(Float64[1:prod(size(dw));], size(dw)..., 1, 1) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack) for conv in convs - if NNlib.is_nnpack_available() - if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) - continue - end - end @testset "$(conv)" begin cdims = DenseConvDims(x, w) # First, your basic convolution with no parameters @@ -313,13 +307,7 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x))) # Test all in-place implementations/interfaces convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!) for conv! in convs - if NNlib.is_nnpack_available() - if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w)) - continue - end - end α, β = 2e0, -1e0 @testset "$(conv!)" begin @@ -470,13 +458,7 @@ end w = reshape(complex.(Float64[1:4;] .+ 2, Float64[1:4;] .+ 3), 1, 4, 1) cdims = DenseConvDims(x, w) convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,] - NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack) for conv in convs - if NNlib.is_nnpack_available() - if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(cdims) - continue - end - end @testset "$(conv)" begin @test isapprox(ddims(conv(x, w, cdims)), [transpose(vec(w)) * vec(x)], rtol = 1.0e-7) end diff --git a/test/inference.jl b/test/inference.jl index 31785597c..9b3e74db8 100644 --- a/test/inference.jl +++ b/test/inference.jl @@ -3,9 +3,6 @@ import NNlib: conv_direct, conv_im2col, channels_in, channels_out @testset "Conv Inference" begin for T in (Float32, Float64) impl = [conv, conv_direct, conv_im2col] - if NNlib.is_nnpack_available() && T == Float32 - push!(impl, NNlib.conv_nnpack) - end x = rand(T, 10, 10, 3, 2) w = rand(T, 3, 3, 3, 1) diff --git a/test/pooling.jl b/test/pooling.jl index 97d014d52..f9d57ade7 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -869,16 +869,6 @@ maxpool_answer_nature = Dict( @test y_maxpool_dir ≈ y_maxpool atol = 1e-6 @test isapprox(config.dx_maxpool, NNlib.∇maxpool_direct(dy, y_maxpool_dir, x, pdims), rtol=1e-5) @test isapprox(config.dx_meanpool, NNlib.∇meanpool_direct(dy, y_meanpool_dir, x, pdims), rtol=1e-5) - - # CHECK NNPACK - if NNlib.is_nnpack_available() && T == Float32 - if NNlib.nnpack_supported_operation(pdims) - y_maxpool_nnp = NNlib.maxpool_nnpack(x, pdims) - @test y_maxpool_nnp ≈ y_maxpool atol = 1e-6 - # NNPACK maxpool gradient still missing - # @test isapprox(config.dx_maxpool, NNlib.∇maxpool_nnpack(dy, y_maxpool_nnp, config.x, pdims), rtol=1e-5) - end - end end for (rank_name, config_dict) in maxpool_answer_nature @@ -940,16 +930,6 @@ maxpool_answer_nature = Dict( @test RD.gradient(_x -> only(maxpool(_x,(2,2))), x)[:,:,1,1] == [0 0; 0 1] @test only(meanpool(x, (2,2))) == 2.5 @test all(==(0.25), RD.gradient(_x -> only(meanpool(_x,(2,2))), x)) - - # if NNlib.is_nnpack_available() - # if NNlib.nnpack_supported_operation(pdims1) - # @test NNlib.maxpool_nnpack(x, pdims1) isa Array{Float32, 4} - # end - # if NNlib.nnpack_supported_operation(pdims2) - # print("you should not see this") - # @test NNlib.maxpool_nnpack(x, pdims2) isa Array{Float32, 4} - # end - # end end @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2)