diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3701be2bb0..e058088156 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2} out_proj::P2 end -@functor MultiHeadAttention +@layer MultiHeadAttention function MultiHeadAttention(dims; nheads::Int = 8, @@ -83,8 +83,8 @@ function MultiHeadAttention(dims; dropout_prob = 0.0) dims = normalize_mha_dims(dims) - @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" - @assert dims.v % nheads == 0 "v_dim should be divisible by nheads" + dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)") + dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)") q_proj = Dense(dims.q_in => dims.qk; bias, init) k_proj = Dense(dims.k_in => dims.qk; bias, init) v_proj = Dense(dims.v_in => dims.v; bias, init) @@ -131,3 +131,42 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, # [α] = [kv_len, q_len, nheads, batch_size] return x, α end + +function Base.show(io::IO, mha::MultiHeadAttention) + qk, q_in = size(mha.q_proj.weight) + qk, k_in = size(mha.k_proj.weight) + v, v_in = size(mha.v_proj.weight) + out, v = size(mha.out_proj.weight) + # @show q_in, k_in, v_in, qk, v, out + print(io, "MultiHeadAttention(") + if q_in == k_in == v_in == qk == v == out + print(io, q_in) + elseif q_in == k_in == v_in && qk == v + print(io, q_in, " => ", qk, " => ", out) + elseif q_in == k_in == v_in + print(io, q_in, " => (", qk, ", ", v,") => ", out) + else + print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out) + end + print(io, "; nheads=", mha.nheads) + if mha.q_proj.bias === true + print(io, ", bias=true") + end + if mha.attn_drop.p != 0 + print(io, ", dropout_prob=", mha.attn_drop.p) # can't we rename this? + end + print(io, ")") +end + +Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha) + +#= + +# Test cases: + +MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => 6 => 8; nheads=1) +MultiHeadAttention(8; bias=true) + +=#