layer MultiHeadAttention, and show methods for this
mcabbott committed Mar 2, 2024
1 parent 35ade0a commit 29e0d68
Showing 1 changed file with 42 additions and 3 deletions.
src/layers/attention.jl
@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
  out_proj::P2
end

-@functor MultiHeadAttention
+@layer MultiHeadAttention

function MultiHeadAttention(dims;
  nheads::Int = 8,
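Context for the one-line change above: Flux.@layer supersedes Flux.@functor for user-defined layers. It registers the struct's fields as trainable children, so Flux.params, gpu and cpu recurse into them, and it opts the type into Flux's layer printing. A minimal sketch of the same pattern on a hypothetical container layer (the struct is illustrative, not from this commit):

using Flux

struct Sandwich{A, B}  # hypothetical two-field layer
  inner::A
  outer::B
end

Flux.@layer Sandwich   # previously: Flux.@functor Sandwich

s = Sandwich(Dense(2 => 3, relu), Dense(3 => 2))
Flux.params(s)  # collects weight and bias arrays from both Dense fields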
@@ -83,8 +83,8 @@ function MultiHeadAttention(dims;
  dropout_prob = 0.0)

  dims = normalize_mha_dims(dims)
-  @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
-  @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
+  dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)"))
+  dims.v % nheads == 0 || throw(ArgumentError("v_dim = $(dims.v) should be divisible by nheads = $(nheads)"))
  q_proj = Dense(dims.q_in => dims.qk; bias, init)
  k_proj = Dense(dims.k_in => dims.qk; bias, init)
  v_proj = Dense(dims.v_in => dims.v; bias, init)
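The new checks fail fast on bad configurations. A hypothetical REPL interaction, with the error text taken from the strings above (a plain integer dims argument makes all dimensions, including qk_dim, equal to that number):

MultiHeadAttention(3; nheads=2)
# ERROR: ArgumentError: qk_dim = 3 should be divisible by nheads = 2

MultiHeadAttention(4 => (6, 8) => 4; nheads=2)  # fine: 6 and 8 are both divisible by 2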
@@ -131,3 +131,42 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
  # [α] = [kv_len, q_len, nheads, batch_size]
  return x, α
end
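To make the shape comments above concrete, a usage sketch assuming Float32 arrays laid out as (features, length, batch), following those comments:

mha = MultiHeadAttention(8; nheads=2)
q  = rand(Float32, 8, 10, 3)  # (q_in, q_len, batch_size)
kv = rand(Float32, 8, 20, 3)  # (kv_in, kv_len, batch_size)
x, α = mha(q, kv, kv)
size(x)  # (8, 10, 3): (out, q_len, batch_size)
size(α)  # (20, 10, 2, 3): (kv_len, q_len, nheads, batch_size)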

+function Base.show(io::IO, mha::MultiHeadAttention)
+  qk, q_in = size(mha.q_proj.weight)
+  qk, k_in = size(mha.k_proj.weight)
+  v, v_in = size(mha.v_proj.weight)
+  out, v = size(mha.out_proj.weight)
+  # @show q_in, k_in, v_in, qk, v, out
+  print(io, "MultiHeadAttention(")
+  if q_in == k_in == v_in == qk == v == out
+    print(io, q_in)
+  elseif q_in == k_in == v_in && qk == v
+    print(io, q_in, " => ", qk, " => ", out)
+  elseif q_in == k_in == v_in
+    print(io, q_in, " => (", qk, ", ", v, ") => ", out)
+  else
+    print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v, ") => ", out)
+  end
+  print(io, "; nheads=", mha.nheads)
+  if mha.q_proj.bias !== false
+    print(io, ", bias=true")
+  end
+  if mha.attn_drop.p != 0
+    print(io, ", dropout_prob=", mha.attn_drop.p)  # can't we rename this?
+  end
+  print(io, ")")
+end
+
+Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha)
+
+#=
+# Test cases:
+MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
+MultiHeadAttention(3 => (6, 7) => 8; nheads=1)
+MultiHeadAttention(3 => 6 => 8; nheads=1)
+MultiHeadAttention(8; bias=true)
+=#
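Tracing the branches of Base.show above, the test cases should print as follows (assuming the bias !== false check, so a constructed bias vector prints as bias=true; the first three round-trip their constructor syntax exactly):

MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)  # last branch: all input dims differ
MultiHeadAttention(3 => (6, 7) => 8; nheads=1)          # q_in == k_in == v_in
MultiHeadAttention(3 => 6 => 8; nheads=1)               # ... and qk == v
MultiHeadAttention(8; nheads=8, bias=true)              # first branch; defaults made explicit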
