From 941d9437d95040e408b0082550894d5f25adbdff Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 5 Dec 2024 12:32:44 +0000 Subject: [PATCH] Remove Redundant gemm Code (#410) * Remove redundant code * Bump patch version * Tidy up gemv perf --- Project.toml | 2 +- src/rrules/blas.jl | 99 ++++------------------------------------------ 2 files changed, 9 insertions(+), 92 deletions(-) diff --git a/Project.toml b/Project.toml index 63df80fb0..654658531 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.56" +version = "0.4.57" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 02f9f153b..2c2fd5d95 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -25,6 +25,13 @@ const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,<:Array{T}}} const VecOrView{T} = Union{Vector{T},SubArray{T,1,<:Array{T}}} const BlasRealFloat = Union{Float32,Float64} +viewify(A::CoDual{<:Vector}) = primal(A), tangent(A) +viewify(A::CoDual{<:Matrix}) = view(primal(A), :, :), view(tangent(A), :, :) +function viewify(A::CoDual{P}) where {P<:SubArray} + p_A = primal(A) + return p_A, P(tangent(A).data.parent, p_A.indices, p_A.offset1, p_A.stride1) +end + # # Utility # @@ -393,96 +400,6 @@ function rrule!!( return C, gemm!_pb!! end -viewify(A::CoDual{<:Vector}) = primal(A), tangent(A) -viewify(A::CoDual{<:Matrix}) = view(primal(A), :, :), view(tangent(A), :, :) -function viewify(A::CoDual{P}) where {P<:SubArray} - p_A = primal(A) - return p_A, P(tangent(A).data.parent, p_A.indices, p_A.offset1, p_A.stride1) -end - -for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) - @eval function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(gemm))}}, - ::CoDual{Val{Cvoid}}, - ::CoDual, # arg types - ::CoDual, # nreq - ::CoDual, # calling convention - tA::CoDual{Ptr{UInt8}}, - tB::CoDual{Ptr{UInt8}}, - m::CoDual{Ptr{BLAS.BlasInt}}, - n::CoDual{Ptr{BLAS.BlasInt}}, - ka::CoDual{Ptr{BLAS.BlasInt}}, - alpha::CoDual{Ptr{$elty}}, - A::CoDual{Ptr{$elty}}, - LDA::CoDual{Ptr{BLAS.BlasInt}}, - B::CoDual{Ptr{$elty}}, - LDB::CoDual{Ptr{BLAS.BlasInt}}, - beta::CoDual{Ptr{$elty}}, - C::CoDual{Ptr{$elty}}, - LDC::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin - _tA = Char(unsafe_load(primal(tA))) - _tB = Char(unsafe_load(primal(tB))) - _m = unsafe_load(primal(m)) - _n = unsafe_load(primal(n)) - _ka = unsafe_load(primal(ka)) - _alpha = unsafe_load(primal(alpha)) - _A = primal(A) - _LDA = unsafe_load(primal(LDA)) - _B = primal(B) - _LDB = unsafe_load(primal(LDB)) - _beta = unsafe_load(primal(beta)) - _C = primal(C) - _LDC = unsafe_load(primal(LDC)) - - A_mat = wrap_ptr_as_view( - primal(A), _LDA, (_tA == 'N' ? (_m, _ka) : (_ka, _m))... - ) - B_mat = wrap_ptr_as_view( - primal(B), _LDB, (_tB == 'N' ? (_ka, _n) : (_n, _ka))... - ) - C_mat = wrap_ptr_as_view(primal(C), _LDC, _m, _n) - C_copy = collect(C_mat) - - BLAS.gemm!(_tA, _tB, _alpha, A_mat, B_mat, _beta, C_mat) - - dalpha = tangent(alpha) - dA = tangent(A) - dB = tangent(B) - dbeta = tangent(beta) - dC = tangent(C) - end - - function gemm!_pullback!!(::NoRData) - GC.@preserve args begin - # Restore previous state. - C_mat .= C_copy - - # Convert pointers to views. - dA_mat = wrap_ptr_as_view(dA, _LDA, (_tA == 'N' ? (_m, _ka) : (_ka, _m))...) - dB_mat = wrap_ptr_as_view(dB, _LDB, (_tB == 'N' ? (_ka, _n) : (_n, _ka))...) - dC_mat = wrap_ptr_as_view(dC, _LDC, _m, _n) - - # Increment cotangents. - unsafe_store!(dbeta, unsafe_load(dbeta) + tr(dC_mat' * C_mat)) - dalpha_inc = tr(dC_mat' * _trans(_tA, A_mat) * _trans(_tB, B_mat)) - unsafe_store!(dalpha, unsafe_load(dalpha) + dalpha_inc) - dA_mat .+= - _alpha * transpose(_trans(_tA, _trans(_tB, B_mat) * transpose(dC_mat))) - dB_mat .+= - _alpha * transpose(_trans(_tB, transpose(dC_mat) * _trans(_tA, A_mat))) - dC_mat .*= _beta - end - - return tuple_fill(NoRData(), Val(19 + Nargs)) - end - return zero_fcodual(Cvoid()), gemm!_pullback!! - end -end - @is_primitive( MinimalCtx, Tuple{ @@ -828,7 +745,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) As = blas_matrices(rng, Float64, t ? M : N, t ? N : M) xs = blas_vectors(rng, Float64, N) ys = blas_vectors(rng, Float64, M) - flags = (false, :stability, (lb=1e-3, ub=5.0)) + flags = (false, :stability, (lb=1e-3, ub=10.0)) return map(product(As, xs, ys)) do (A, x, y) return (flags..., BLAS.gemv!, tA, randn(), A, x, randn(), y) end