diff --git a/benchmark/suites/leastsquares.jl b/benchmark/suites/leastsquares.jl index bb8471c..16dda4f 100644 --- a/benchmark/suites/leastsquares.jl +++ b/benchmark/suites/leastsquares.jl @@ -34,5 +34,4 @@ for sz in [(15, 20, 25, 30, 35), (30, 30, 30, 30, 30)], r in [1, 10, 50] @benchmarkable gcp($X, $r, loss = GCPLosses.LeastSquaresLoss()) end - -end \ No newline at end of file +end diff --git a/benchmark/suites/mttkrps.jl b/benchmark/suites/mttkrps.jl index f5b5dc9..3d2129c 100644 --- a/benchmark/suites/mttkrps.jl +++ b/benchmark/suites/mttkrps.jl @@ -12,7 +12,7 @@ const SETUPS = [] append!( SETUPS, [ - (; modes=3, size = sz, rank = r) for + (; modes = 3, size = sz, rank = r) for sz in [ntuple(n -> In, 3) for In in [20, 50, 100, 200, 400]], r in [10, 100] ], ) @@ -21,7 +21,7 @@ append!( append!( SETUPS, [ - (; modes=4, size = sz, rank = r) for + (; modes = 4, size = sz, rank = r) for sz in [ntuple(n -> In, 4) for In in [20, 50, 100]], r in [10, 100] ], ) @@ -30,12 +30,11 @@ append!( append!( SETUPS, [ - (; modes=5, size = sz, rank = r) for + (; modes = 5, size = sz, rank = r) for sz in [ntuple(n -> In, 5) for In in [10, 30, 60]], r in [10, 100] ], ) - # Generate random benchmarks for SETUP in SETUPS @@ -53,15 +52,8 @@ for SETUP in SETUPS end λ, U = M0.λ, collect(M0.U) - SUITE["modes=$(SETUP.modes), size=$(SETUP.size), rank=$(SETUP.rank)"] = @benchmarkable( - GCPDecompositions.mttkrps!($X, $U, $λ), - seconds = 5, - samples = 5, - ) + SUITE["modes=$(SETUP.modes), size=$(SETUP.size), rank=$(SETUP.rank)"] = + @benchmarkable(GCPDecompositions.mttkrps!($X, $U, $λ), seconds = 5, samples = 5,) end - - - - end diff --git a/src/gcp-algorithms/fastals.jl b/src/gcp-algorithms/fastals.jl index e3ce4a4..fb2d036 100644 --- a/src/gcp-algorithms/fastals.jl +++ b/src/gcp-algorithms/fastals.jl @@ -53,7 +53,6 @@ end section III-C. """ function FastALS_iter!(X, M, order, Jns, Kns, buffers) - N = ndims(X) R = size(M.U[1])[2] @@ -63,11 +62,25 @@ function FastALS_iter!(X, M, order, Jns, Kns, buffers) if n == n_star if n == 1 khatrirao!(buffers.kr_buffer_descending, M.U[reverse(n+1:N)]...) - mul!(M.U[n], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending) + mul!(M.U[n], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending) else khatrirao!(buffers.kr_buffer_descending, M.U[reverse(n+1:N)]...) - mul!(buffers.descending_buffers[1], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending) - FastALS_mttkrps_helper!(buffers.descending_buffers[1], M.U, n_star, n, "right", N, Jns, Kns, buffers) + mul!( + buffers.descending_buffers[1], + reshape(X, (Jns[n], Kns[n])), + buffers.kr_buffer_descending, + ) + FastALS_mttkrps_helper!( + buffers.descending_buffers[1], + M.U, + n_star, + n, + "right", + N, + Jns, + Kns, + buffers, + ) end elseif n == n_star + 1 if n == N @@ -75,30 +88,92 @@ function FastALS_iter!(X, M, order, Jns, Kns, buffers) mul!(M.U[n], reshape(X, (Jns[n-1], Kns[n-1]))', buffers.kr_buffer_ascending) else khatrirao!(buffers.kr_buffer_ascending, M.U[reverse(1:n-1)]...) - mul!(buffers.ascending_buffers[1], (reshape(X, (Jns[n-1], Kns[n-1])))', buffers.kr_buffer_ascending) - FastALS_mttkrps_helper!(buffers.ascending_buffers[1], M.U, n_star, n, "left", N, Jns, Kns, buffers) - end + mul!( + buffers.ascending_buffers[1], + (reshape(X, (Jns[n-1], Kns[n-1])))', + buffers.kr_buffer_ascending, + ) + FastALS_mttkrps_helper!( + buffers.ascending_buffers[1], + M.U, + n_star, + n, + "left", + N, + Jns, + Kns, + buffers, + ) + end elseif n < n_star if n == 1 for r in 1:R - mul!(view(M.U[n], :, r), reshape(view(buffers.descending_buffers[n_star-n], :, r), (Jns[n], size(X)[n+1])), view(M.U[n+1], :, r)) + mul!( + view(M.U[n], :, r), + reshape( + view(buffers.descending_buffers[n_star-n], :, r), + (Jns[n], size(X)[n+1]), + ), + view(M.U[n+1], :, r), + ) end else for r in 1:R - mul!(view(buffers.descending_buffers[n_star-n+1], :, r), reshape(view(buffers.descending_buffers[n_star-n], :, r), (Jns[n], size(X)[n+1])), view(M.U[n+1], :, r)) + mul!( + view(buffers.descending_buffers[n_star-n+1], :, r), + reshape( + view(buffers.descending_buffers[n_star-n], :, r), + (Jns[n], size(X)[n+1]), + ), + view(M.U[n+1], :, r), + ) end - FastALS_mttkrps_helper!(buffers.descending_buffers[n_star-n+1], M.U, n_star, n, "right", N, Jns, Kns, buffers) + FastALS_mttkrps_helper!( + buffers.descending_buffers[n_star-n+1], + M.U, + n_star, + n, + "right", + N, + Jns, + Kns, + buffers, + ) end else if n == N for r in 1:R - mul!(view(M.U[n], :, r), reshape(view(buffers.ascending_buffers[N-n_star-1], :, r), (size(X)[n-1], Kns[n-1]))', view(M.U[n-1], :, r)) + mul!( + view(M.U[n], :, r), + reshape( + view(buffers.ascending_buffers[N-n_star-1], :, r), + (size(X)[n-1], Kns[n-1]), + )', + view(M.U[n-1], :, r), + ) end else for r in 1:R - mul!(view(buffers.ascending_buffers[n-n_star], :, r), reshape(view(buffers.ascending_buffers[n-n_star-1], :, r), (size(X)[n-1], Kns[n-1]))', view(M.U[n-1], :, r)) + mul!( + view(buffers.ascending_buffers[n-n_star], :, r), + reshape( + view(buffers.ascending_buffers[n-n_star-1], :, r), + (size(X)[n-1], Kns[n-1]), + )', + view(M.U[n-1], :, r), + ) end - FastALS_mttkrps_helper!(buffers.ascending_buffers[n-n_star], M.U, n_star, n, "left", N, Jns, Kns, buffers) + FastALS_mttkrps_helper!( + buffers.ascending_buffers[n-n_star], + M.U, + n_star, + n, + "left", + N, + Jns, + Kns, + buffers, + ) end end # Normalization, update weights @@ -107,19 +182,26 @@ function FastALS_iter!(X, M, order, Jns, Kns, buffers) M.λ .= norm.(eachcol(M.U[n])) M.U[n] ./= permutedims(M.λ) end -end - +end function FastALS_mttkrps_helper!(Zn, U, n_star, n, side, N, Jns, Kns, buffers) if side == "right" khatrirao!(buffers.helper_buffers_descending[n_star-n+1], U[reverse(1:n-1)]...) for r in 1:size(U[n])[2] - mul!(view(U[n], :, r), reshape(view(Zn, :, r), (Jns[n-1], size(U[n])[1]))', view(buffers.helper_buffers_descending[n_star-n+1], :, r)) + mul!( + view(U[n], :, r), + reshape(view(Zn, :, r), (Jns[n-1], size(U[n])[1]))', + view(buffers.helper_buffers_descending[n_star-n+1], :, r), + ) end elseif side == "left" khatrirao!(buffers.helper_buffers_ascending[n-n_star], U[reverse(n+1:N)]...) for r in 1:size(U[n])[2] - mul!(view(U[n], :, r), reshape(view(Zn, :, r), (size(U[n])[1], Kns[n])), view(buffers.helper_buffers_ascending[n-n_star], :, r)) + mul!( + view(U[n], :, r), + reshape(view(Zn, :, r), (size(U[n])[1], Kns[n])), + view(buffers.helper_buffers_ascending[n-n_star], :, r), + ) end end end @@ -127,25 +209,34 @@ end function create_FastALS_buffers( U::NTuple{N,TM}, order, - Jns, + Jns, Kns, ) where {TM<:AbstractMatrix,N} - n_star = order[1] r = size(U[1])[2] dims = [size(U[u])[1] for u in 1:length(U)] # Allocate buffers # Buffer for saved products between modes - descending_buffers = n_star < 2 ? nothing : [similar(U[1], (Jns[n], r)) for n in n_star:-1:2] - ascending_buffers = N - n_star - 1 < 1 ? nothing : [similar(U[1], (Kns[n], r)) for n in n_star:N] + descending_buffers = + n_star < 2 ? nothing : [similar(U[1], (Jns[n], r)) for n in n_star:-1:2] + ascending_buffers = + N - n_star - 1 < 1 ? nothing : [similar(U[1], (Kns[n], r)) for n in n_star:N] # Buffers for khatri-rao products kr_buffer_descending = similar(U[1], (Kns[n_star], r)) kr_buffer_ascending = similar(U[1], (Jns[n_star], r)) # Buffers for khatri-rao product in helper function - helper_buffers_descending = n_star < 2 ? nothing : [similar(U[1], (prod(dims[1:n-1]), r)) for n in n_star:-1:2] - helper_buffers_ascending = n_star >= N-1 ? nothing : [similar(U[1], (prod(dims[n+1:N]), r)) for n in n_star+1:N-1] - return(; descending_buffers, ascending_buffers, - kr_buffer_descending, kr_buffer_ascending, - helper_buffers_descending, helper_buffers_ascending) + helper_buffers_descending = + n_star < 2 ? nothing : [similar(U[1], (prod(dims[1:n-1]), r)) for n in n_star:-1:2] + helper_buffers_ascending = + n_star >= N - 1 ? nothing : + [similar(U[1], (prod(dims[n+1:N]), r)) for n in n_star+1:N-1] + return (; + descending_buffers, + ascending_buffers, + kr_buffer_descending, + kr_buffer_ascending, + helper_buffers_descending, + helper_buffers_ascending, + ) end