From fe7208168e47faeabfada4991103b62279e33f7d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 2 Mar 2024 14:39:36 -0500
Subject: [PATCH] fixup attention

---
 src/layers/attention.jl | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index e058088156..d4a33283d9 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -83,8 +83,8 @@ function MultiHeadAttention(dims;
                      dropout_prob = 0.0)
 
   dims = normalize_mha_dims(dims)
-  dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)")
-  dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)")
+  dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)"))
+  dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)"))
   q_proj = Dense(dims.q_in => dims.qk; bias, init)
   k_proj = Dense(dims.k_in => dims.qk; bias, init)
   v_proj = Dense(dims.v_in => dims.v; bias, init)
@@ -149,7 +149,7 @@ function Base.show(io::IO, mha::MultiHeadAttention)
     print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out)
   end
   print(io, "; nheads=", mha.nheads)
-  if mha.q_proj.bias === true
+  if mha.q_proj.bias !== false
     print(io, ", bias=true")
   end
   if mha.attn_drop.p != 0
@@ -158,11 +158,10 @@ function Base.show(io::IO, mha::MultiHeadAttention)
   print(io, ")")
 end
 
-Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha)
 
 #=
 
-# Test cases:
+# Test cases for printing:
 
 MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
 MultiHeadAttention(3 => (6, 7) => 8; nheads=1)