Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster Khatri-Rao #41

Merged
merged 10 commits into from
Mar 1, 2024
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const SUITE_MODULES = Dict(
"gcp" => :BenchmarkGCP,
"mttkrp" => :BenchmarkMTTKRP,
"mttkrp-large" => :BenchmarkMTTKRPLarge,
"khatrirao" => :BenchmarkKhatriRao,
)

# Create top-level suite including only sub-suites
Expand Down
73 changes: 73 additions & 0 deletions benchmark/suites/khatrirao.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module BenchmarkKhatriRao

using BenchmarkTools, GCPDecompositions
using Random

const SUITE = BenchmarkGroup()

# Collect setups
const SETUPS = []

## N=1 matrix
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 1) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=2 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 2) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=3 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 3) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=3 matrices (imbalanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for
sz in [Tuple(circshift([30, 100, 1000], c)) for c in 0:2], r in [5; 30:30:90]
],
)

## N=4 matrices (balanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for sz in [ntuple(n -> In, 4) for In in 30:30:90],
r in [5; 30:30:90]
],
)

## N=4 matrices (imbalanced)
append!(
SETUPS,
[
(; size = sz, rank = r) for
sz in [Tuple(circshift([20, 40, 80, 500], c)) for c in 0:3], r in [5; 30:30:90]
],
)

# Generate random benchmarks
for SETUP in SETUPS
Random.seed!(0)
U = [randn(In, SETUP.rank) for In in SETUP.size]
SUITE["size=$(SETUP.size), rank=$(SETUP.rank)"] =
@benchmarkable(GCPDecompositions.khatrirao($U...), seconds = 2, samples = 5,)
end

end
15 changes: 9 additions & 6 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,14 @@ function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N}
return A[1]
end

# General case: N > 1
r = (only ∘ unique)(size.(A, 2))
K = similar(A[1], prod(size.(A, 1)), r)
for j in 1:r
K[:, j] = reduce(kron, [view(A[i], :, j) for i in 1:N])
# Base case: N = 2
if N == 2
r = (only ∘ unique)(size.(A, 2))
return reshape(reshape(A[2], :, 1, r) .* reshape(A[1], 1, :, r), :, r)
end
return K

# Recursive case: N > 2
I, r = size.(A, 1), (only ∘ unique)(size.(A, 2))
n = argmin(n -> I[n] * I[n+1], 1:N-1)
return khatrirao(A[1:n-1]..., khatrirao(A[n], A[n+1]), A[n+2:end]...)
end
17 changes: 17 additions & 0 deletions test/items/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ end
Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end

# 4-way tensor to exercise recursive part of the Khatri-Rao code
@testset "size(X)=$sz, rank(X)=$r" for sz in [(50, 40, 30, 2)], r in 1:2
Random.seed!(0)
M = CPD(ones(r), rand.(sz, r))
X = [M[I] for I in CartesianIndices(size(M))]
Mh = gcp(X, r, LeastSquaresLoss())
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5

Xm = convert(Array{Union{Missing,eltype(X)}}, X)
Xm[1, 1, 1, 1] = missing
Mm = gcp(Xm, r, LeastSquaresLoss())
@test maximum(I -> abs(Mm[I] - X[I]), CartesianIndices(X)) <= 1e-5

Mh = gcp(X, r) # test default (least-squares) loss
@test maximum(I -> abs(Mh[I] - X[I]), CartesianIndices(X)) <= 1e-5
end
end

@testitem "NonnegativeLeastSquaresLoss" begin
Expand Down
Loading