Skip to content

Commit

Permalink
fixup attention
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Mar 2, 2024
1 parent 2be9099 commit fe72081
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/layers/attention.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ function MultiHeadAttention(dims;
dropout_prob = 0.0)

dims = normalize_mha_dims(dims)
dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)")
dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)")
dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)"))
dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)"))
q_proj = Dense(dims.q_in => dims.qk; bias, init)
k_proj = Dense(dims.k_in => dims.qk; bias, init)
v_proj = Dense(dims.v_in => dims.v; bias, init)
Expand Down Expand Up @@ -149,7 +149,7 @@ function Base.show(io::IO, mha::MultiHeadAttention)
print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out)

Check warning on line 149 in src/layers/attention.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/attention.jl#L149

Added line #L149 was not covered by tests
end
print(io, "; nheads=", mha.nheads)
if mha.q_proj.bias === true
if mha.q_proj.bias !== false
print(io, ", bias=true")

Check warning on line 153 in src/layers/attention.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/attention.jl#L151-L153

Added lines #L151 - L153 were not covered by tests
end
if mha.attn_drop.p != 0
Expand All @@ -158,11 +158,10 @@ function Base.show(io::IO, mha::MultiHeadAttention)
print(io, ")")

Check warning on line 158 in src/layers/attention.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/attention.jl#L158

Added line #L158 was not covered by tests
end

Base.show(io::IO, ::MIME"text/plain", mha::MultiHeadAttention) = show(io, mha)

#=
# Test cases:
# Test cases for printing:
MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
MultiHeadAttention(3 => (6, 7) => 8; nheads=1)
Expand Down

0 comments on commit fe72081

Please sign in to comment.