Skip to content

Commit

Permalink
Run JuliaFormatter
Browse files Browse the repository at this point in the history
  • Loading branch information
dahong67 committed Mar 15, 2024
1 parent ada19ee commit 0344998
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 41 deletions.
3 changes: 1 addition & 2 deletions benchmark/suites/leastsquares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,4 @@ for sz in [(15, 20, 25, 30, 35), (30, 30, 30, 30, 30)], r in [1, 10, 50]
@benchmarkable gcp($X, $r, loss = GCPLosses.LeastSquaresLoss())
end


end
end
18 changes: 5 additions & 13 deletions benchmark/suites/mttkrps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const SETUPS = []
append!(
SETUPS,
[
(; modes=3, size = sz, rank = r) for
(; modes = 3, size = sz, rank = r) for
sz in [ntuple(n -> In, 3) for In in [20, 50, 100, 200, 400]], r in [10, 100]
],
)
Expand All @@ -21,7 +21,7 @@ append!(
append!(
SETUPS,
[
(; modes=4, size = sz, rank = r) for
(; modes = 4, size = sz, rank = r) for
sz in [ntuple(n -> In, 4) for In in [20, 50, 100]], r in [10, 100]
],
)
Expand All @@ -30,12 +30,11 @@ append!(
append!(
SETUPS,
[
(; modes=5, size = sz, rank = r) for
(; modes = 5, size = sz, rank = r) for
sz in [ntuple(n -> In, 5) for In in [10, 30, 60]], r in [10, 100]
],
)


# Generate random benchmarks
for SETUP in SETUPS

Expand All @@ -53,15 +52,8 @@ for SETUP in SETUPS
end
λ, U = M0.λ, collect(M0.U)

SUITE["modes=$(SETUP.modes), size=$(SETUP.size), rank=$(SETUP.rank)"] = @benchmarkable(
GCPDecompositions.mttkrps!($X, $U, $λ),
seconds = 5,
samples = 5,
)
SUITE["modes=$(SETUP.modes), size=$(SETUP.size), rank=$(SETUP.rank)"] =
@benchmarkable(GCPDecompositions.mttkrps!($X, $U, $λ), seconds = 5, samples = 5,)
end





end
143 changes: 117 additions & 26 deletions src/gcp-algorithms/fastals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ end
section III-C.
"""
function FastALS_iter!(X, M, order, Jns, Kns, buffers)

N = ndims(X)
R = size(M.U[1])[2]

Expand All @@ -63,42 +62,118 @@ function FastALS_iter!(X, M, order, Jns, Kns, buffers)
if n == n_star
if n == 1
khatrirao!(buffers.kr_buffer_descending, M.U[reverse(n+1:N)]...)
mul!(M.U[n], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending)
mul!(M.U[n], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending)
else
khatrirao!(buffers.kr_buffer_descending, M.U[reverse(n+1:N)]...)
mul!(buffers.descending_buffers[1], reshape(X, (Jns[n], Kns[n])), buffers.kr_buffer_descending)
FastALS_mttkrps_helper!(buffers.descending_buffers[1], M.U, n_star, n, "right", N, Jns, Kns, buffers)
mul!(
buffers.descending_buffers[1],
reshape(X, (Jns[n], Kns[n])),
buffers.kr_buffer_descending,
)
FastALS_mttkrps_helper!(
buffers.descending_buffers[1],
M.U,
n_star,
n,
"right",
N,
Jns,
Kns,
buffers,
)
end
elseif n == n_star + 1
if n == N
khatrirao!(buffers.kr_buffer_ascending, M.U[reverse(1:n-1)]...)
mul!(M.U[n], reshape(X, (Jns[n-1], Kns[n-1]))', buffers.kr_buffer_ascending)
else
khatrirao!(buffers.kr_buffer_ascending, M.U[reverse(1:n-1)]...)
mul!(buffers.ascending_buffers[1], (reshape(X, (Jns[n-1], Kns[n-1])))', buffers.kr_buffer_ascending)
FastALS_mttkrps_helper!(buffers.ascending_buffers[1], M.U, n_star, n, "left", N, Jns, Kns, buffers)
end
mul!(
buffers.ascending_buffers[1],
(reshape(X, (Jns[n-1], Kns[n-1])))',
buffers.kr_buffer_ascending,
)
FastALS_mttkrps_helper!(
buffers.ascending_buffers[1],
M.U,
n_star,
n,
"left",
N,
Jns,
Kns,
buffers,
)
end
elseif n < n_star
if n == 1
for r in 1:R
mul!(view(M.U[n], :, r), reshape(view(buffers.descending_buffers[n_star-n], :, r), (Jns[n], size(X)[n+1])), view(M.U[n+1], :, r))
mul!(
view(M.U[n], :, r),
reshape(
view(buffers.descending_buffers[n_star-n], :, r),
(Jns[n], size(X)[n+1]),
),
view(M.U[n+1], :, r),
)
end
else
for r in 1:R
mul!(view(buffers.descending_buffers[n_star-n+1], :, r), reshape(view(buffers.descending_buffers[n_star-n], :, r), (Jns[n], size(X)[n+1])), view(M.U[n+1], :, r))
mul!(

Check warning on line 122 in src/gcp-algorithms/fastals.jl

View check run for this annotation

Codecov / codecov/patch

src/gcp-algorithms/fastals.jl#L122

Added line #L122 was not covered by tests
view(buffers.descending_buffers[n_star-n+1], :, r),
reshape(
view(buffers.descending_buffers[n_star-n], :, r),
(Jns[n], size(X)[n+1]),
),
view(M.U[n+1], :, r),
)
end

Check warning on line 130 in src/gcp-algorithms/fastals.jl

View check run for this annotation

Codecov / codecov/patch

src/gcp-algorithms/fastals.jl#L130

Added line #L130 was not covered by tests
FastALS_mttkrps_helper!(buffers.descending_buffers[n_star-n+1], M.U, n_star, n, "right", N, Jns, Kns, buffers)
FastALS_mttkrps_helper!(
buffers.descending_buffers[n_star-n+1],
M.U,
n_star,
n,
"right",
N,
Jns,
Kns,
buffers,
)
end
else
if n == N
for r in 1:R
mul!(view(M.U[n], :, r), reshape(view(buffers.ascending_buffers[N-n_star-1], :, r), (size(X)[n-1], Kns[n-1]))', view(M.U[n-1], :, r))
mul!(
view(M.U[n], :, r),
reshape(
view(buffers.ascending_buffers[N-n_star-1], :, r),
(size(X)[n-1], Kns[n-1]),
)',
view(M.U[n-1], :, r),
)
end
else
for r in 1:R
mul!(view(buffers.ascending_buffers[n-n_star], :, r), reshape(view(buffers.ascending_buffers[n-n_star-1], :, r), (size(X)[n-1], Kns[n-1]))', view(M.U[n-1], :, r))
mul!(

Check warning on line 157 in src/gcp-algorithms/fastals.jl

View check run for this annotation

Codecov / codecov/patch

src/gcp-algorithms/fastals.jl#L156-L157

Added lines #L156 - L157 were not covered by tests
view(buffers.ascending_buffers[n-n_star], :, r),
reshape(
view(buffers.ascending_buffers[n-n_star-1], :, r),
(size(X)[n-1], Kns[n-1]),
)',
view(M.U[n-1], :, r),
)
end
FastALS_mttkrps_helper!(buffers.ascending_buffers[n-n_star], M.U, n_star, n, "left", N, Jns, Kns, buffers)
FastALS_mttkrps_helper!(

Check warning on line 166 in src/gcp-algorithms/fastals.jl

View check run for this annotation

Codecov / codecov/patch

src/gcp-algorithms/fastals.jl#L165-L166

Added lines #L165 - L166 were not covered by tests
buffers.ascending_buffers[n-n_star],
M.U,
n_star,
n,
"left",
N,
Jns,
Kns,
buffers,
)
end
end
# Normalization, update weights
Expand All @@ -107,45 +182,61 @@ function FastALS_iter!(X, M, order, Jns, Kns, buffers)
M.λ .= norm.(eachcol(M.U[n]))
M.U[n] ./= permutedims(M.λ)
end
end

end

function FastALS_mttkrps_helper!(Zn, U, n_star, n, side, N, Jns, Kns, buffers)
if side == "right"
khatrirao!(buffers.helper_buffers_descending[n_star-n+1], U[reverse(1:n-1)]...)
for r in 1:size(U[n])[2]
mul!(view(U[n], :, r), reshape(view(Zn, :, r), (Jns[n-1], size(U[n])[1]))', view(buffers.helper_buffers_descending[n_star-n+1], :, r))
mul!(
view(U[n], :, r),
reshape(view(Zn, :, r), (Jns[n-1], size(U[n])[1]))',
view(buffers.helper_buffers_descending[n_star-n+1], :, r),
)
end
elseif side == "left"
khatrirao!(buffers.helper_buffers_ascending[n-n_star], U[reverse(n+1:N)]...)
for r in 1:size(U[n])[2]
mul!(view(U[n], :, r), reshape(view(Zn, :, r), (size(U[n])[1], Kns[n])), view(buffers.helper_buffers_ascending[n-n_star], :, r))
mul!(
view(U[n], :, r),
reshape(view(Zn, :, r), (size(U[n])[1], Kns[n])),
view(buffers.helper_buffers_ascending[n-n_star], :, r),
)
end
end
end

function create_FastALS_buffers(
U::NTuple{N,TM},
order,
Jns,
Jns,
Kns,
) where {TM<:AbstractMatrix,N}

n_star = order[1]
r = size(U[1])[2]
dims = [size(U[u])[1] for u in 1:length(U)]

# Allocate buffers
# Buffer for saved products between modes
descending_buffers = n_star < 2 ? nothing : [similar(U[1], (Jns[n], r)) for n in n_star:-1:2]
ascending_buffers = N - n_star - 1 < 1 ? nothing : [similar(U[1], (Kns[n], r)) for n in n_star:N]
descending_buffers =
n_star < 2 ? nothing : [similar(U[1], (Jns[n], r)) for n in n_star:-1:2]
ascending_buffers =
N - n_star - 1 < 1 ? nothing : [similar(U[1], (Kns[n], r)) for n in n_star:N]
# Buffers for khatri-rao products
kr_buffer_descending = similar(U[1], (Kns[n_star], r))
kr_buffer_ascending = similar(U[1], (Jns[n_star], r))
# Buffers for khatri-rao product in helper function
helper_buffers_descending = n_star < 2 ? nothing : [similar(U[1], (prod(dims[1:n-1]), r)) for n in n_star:-1:2]
helper_buffers_ascending = n_star >= N-1 ? nothing : [similar(U[1], (prod(dims[n+1:N]), r)) for n in n_star+1:N-1]
return(; descending_buffers, ascending_buffers,
kr_buffer_descending, kr_buffer_ascending,
helper_buffers_descending, helper_buffers_ascending)
helper_buffers_descending =
n_star < 2 ? nothing : [similar(U[1], (prod(dims[1:n-1]), r)) for n in n_star:-1:2]
helper_buffers_ascending =
n_star >= N - 1 ? nothing :
[similar(U[1], (prod(dims[n+1:N]), r)) for n in n_star+1:N-1]
return (;
descending_buffers,
ascending_buffers,
kr_buffer_descending,
kr_buffer_ascending,
helper_buffers_descending,
helper_buffers_ascending,
)
end

0 comments on commit 0344998

Please sign in to comment.