Commit 2ed7b6a

Multihead Attention Fixes (#209)
* multihead_attention_optimization: allocate in init
* multihead_attention_optimization: remove last runtime allocation
* multihead_attention_optimization: make attention mask actually useable
* multihead_attention_optimization: tests cleanup
* multihead_attention_optimization: cleanup
* multihead_attention_optimization: refactoring, split methods even more (will be needed for llama attention)
* multihead_attention_optimization: make attributes public
* multihead_attention_optimization: move heads separation out of sdpa backward
* multihead_attention_optimization: add attention mask to self_attention
1 parent 1db0258 · commit 2ed7b6a
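Much of this commit is about making the attention mask a usable, optional argument throughout the layer. As a point of reference, below is a minimal, self-contained sketch of the kind of rank-2 (sequence_length x sequence_length) mask the fixed interfaces accept. The additive large-negative convention and the row-is-query orientation are assumptions for illustration; only the mask's rank and its plumbing come from this commit.

program causal_mask_sketch
  implicit none
  integer, parameter :: sequence_length = 4
  real :: attention_mask(sequence_length, sequence_length)
  integer :: i, j

  ! causal mask: assume rows index the query position and columns the key position,
  ! and that masked positions carry a large negative additive bias (both assumptions)
  attention_mask = 0.
  do concurrent (i = 1:sequence_length, j = 1:sequence_length, j > i)
    attention_mask(i, j) = -1e9
  end do

  do i = 1, sequence_length
    print '(4es11.2)', attention_mask(i, :)
  end do
end program causal_mask_sketch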

4 files changed: +217, -127 lines

src/nf/nf_multihead_attention.f90 (+34, -6)
@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
     real, allocatable :: k_input(:, :)
     real, allocatable :: v_input(:, :)
     real, allocatable :: o_input(:, :)
+
+    ! temporary storages for forward and backward passes
+    real, allocatable :: normalized_attention(:, :, :)
+    real, allocatable :: q_or_dq(:, :, :)
+    real, allocatable :: k_or_dk(:, :, :)
+    real, allocatable :: v_or_dv(:, :, :)
+    real, allocatable :: d_output(:, :, :)
+    real, allocatable :: v_heads(:, :, :)
+    real, allocatable :: k_heads(:, :, :)
+    real, allocatable :: q_heads(:, :, :)
+    real, allocatable :: d_sdpa(:, :)
+    real, allocatable :: jacobian(:, :)
+    real, allocatable :: d_normalize(:, :, :)
   contains

     procedure :: common_backward
     procedure :: common_forward
+    procedure :: sdpa_forward
+    procedure :: sdpa_backward
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
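The new components above are the work buffers behind the "allocate in init" and "remove last runtime allocation" items in the commit message: they are sized once and then reused by every forward and backward call. A standalone sketch of that pattern follows, with illustrative names and sizes rather than the library's actual init code.

program allocate_in_init_sketch
  implicit none
  integer, parameter :: sequence_length = 4, head_size = 8, n_heads = 2
  real, allocatable :: q_heads(:, :, :), k_heads(:, :, :), v_heads(:, :, :)
  real, allocatable :: d_sdpa(:, :), jacobian(:, :)
  integer :: step

  ! "init" phase: size every work buffer once from the layer dimensions
  allocate(q_heads(sequence_length, head_size, n_heads))
  allocate(k_heads, mold=q_heads)
  allocate(v_heads, mold=q_heads)
  allocate(d_sdpa(sequence_length, sequence_length))
  allocate(jacobian, mold=d_sdpa)

  ! "forward"/"backward" phase: overwrite the buffers in place on every pass,
  ! so no allocation or deallocation happens while training or predicting
  do step = 1, 3
    call random_number(q_heads)
    call random_number(k_heads)
    call random_number(v_heads)
    d_sdpa = 0.
    jacobian = 0.
  end do

  print *, 'buffers allocated once, reused across', step - 1, 'passes'
end program allocate_in_init_sketch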
@@ -68,25 +83,38 @@ end function multihead_attention_layer_cons

   interface

-    pure module subroutine common_backward(self, input, gradient)
+    pure module subroutine common_backward(self, input, gradient, attention_mask)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
       !! Cross Attention: use them separately
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: input(:, :)
       real, intent(in) :: gradient(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_backward

-    pure module subroutine common_forward(self, query, key, value)
+    pure module subroutine common_forward(self, query, key, value, attention_mask)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
       !! Cross Attention: pass three values for your query, key and value
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_forward

+    pure module subroutine sdpa_forward(self, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_forward
+
+    pure module subroutine sdpa_backward(self, gradient, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: gradient(:, :)
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_backward
+
     pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
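With the mask added as an optional trailing argument, existing unmasked calls keep compiling while masked self-attention becomes possible. Below is a hedged usage sketch against the interfaces in this hunk; the wrapper subroutine is hypothetical (not part of the library) and assumes the layer has already been constructed and initialized by the caller.

subroutine run_masked_self_attention(attn, x, d_out, mask)
  ! Hypothetical wrapper, shown only to exercise the updated optional-argument interfaces.
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none
  type(multihead_attention_layer), intent(in out) :: attn  ! assumed already constructed and initialized
  real, intent(in) :: x(:, :)      ! activations, assumed (sequence_length, model_dimension)
  real, intent(in) :: d_out(:, :)  ! gradient w.r.t. the attention output, same shape as x
  real, intent(in) :: mask(:, :)   ! (sequence_length, sequence_length), e.g. a causal mask

  ! self-attention: the same tensor is passed as query, key and value
  call attn % common_forward(x, x, x, attention_mask=mask)

  ! the same optional mask can now be supplied on the backward pass
  call attn % common_backward(x, d_out, attention_mask=mask)

  ! omitting the optional argument keeps the previous, unmasked behaviour
  call attn % common_forward(x, x, x)
end subroutine run_masked_self_attention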
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
       !! Output dims: sequence_length, sequence_length, n_heads
       class(multihead_attention_layer), intent(in out) :: self
       !! (sequence_length, sequence_length, n_heads)
-      real, optional, intent(in) :: attention_mask(:, :, :)
+      real, optional, intent(in) :: attention_mask(:, :)
       !! (sequence_length, sequence_length, n_heads)
     end subroutine normalize_attention_matrix

@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
     end function get_num_params

     module function get_params(self) result(params)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: params(:)
     end function get_params

     module function get_gradients(self) result(gradients)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: gradients(:)
     end function get_gradients

     module subroutine set_params(self, params)
       class(multihead_attention_layer), intent(in out) :: self
-      real, intent(in), target :: params(:)
+      real, intent(in) :: params(:)
     end subroutine set_params

     module subroutine init_base(self, input_shape)
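Dropping the target attribute from self and params in this hunk suggests the parameter getters and setter no longer rely on pointer views into the layer's arrays and instead work on plain copies. A standalone illustration of that flatten-and-copy pattern, with made-up array names and sizes, not the library's actual code:

program flat_params_sketch
  implicit none
  real :: query_weights(16, 16), output_biases(16)  ! illustrative parameter arrays
  real, allocatable :: params(:)

  call random_number(query_weights)
  call random_number(output_biases)

  ! flatten and concatenate into a fresh copy; no pointer aliasing, hence no target needed
  params = [reshape(query_weights, [size(query_weights)]), output_biases]

  print *, 'total parameters:', size(params)   ! 16*16 + 16 = 272
end program flat_params_sketch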
