@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
     real, allocatable :: k_input(:, :)
     real, allocatable :: v_input(:, :)
     real, allocatable :: o_input(:, :)
+
+    ! temporary storages for forward and backward passes
+    real, allocatable :: normalized_attention(:, :, :)
+    real, allocatable :: q_or_dq(:, :, :)
+    real, allocatable :: k_or_dk(:, :, :)
+    real, allocatable :: v_or_dv(:, :, :)
+    real, allocatable :: d_output(:, :, :)
+    real, allocatable :: v_heads(:, :, :)
+    real, allocatable :: k_heads(:, :, :)
+    real, allocatable :: q_heads(:, :, :)
+    real, allocatable :: d_sdpa(:, :)
+    real, allocatable :: jacobian(:, :)
+    real, allocatable :: d_normalize(:, :, :)
   contains
 
     procedure :: common_backward
     procedure :: common_forward
+    procedure :: sdpa_forward
+    procedure :: sdpa_backward
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
@@ -68,25 +83,38 @@ end function multihead_attention_layer_cons
 
   interface
 
-    pure module subroutine common_backward(self, input, gradient)
+    pure module subroutine common_backward(self, input, gradient, attention_mask)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
       !! Cross Attention: use them separately
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: input(:, :)
       real, intent(in) :: gradient(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_backward
 
-    pure module subroutine common_forward(self, query, key, value)
+    pure module subroutine common_forward(self, query, key, value, attention_mask)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
       !! Cross Attention: pass three values for your query, key and value
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_forward
 
+    pure module subroutine sdpa_forward(self, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_forward
+
+    pure module subroutine sdpa_backward(self, gradient, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: gradient(:, :)
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_backward
+
     pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
       !! Output dims: sequence_length, sequence_length, n_heads
       class(multihead_attention_layer), intent(in out) :: self
       !! (sequence_length, sequence_length, n_heads)
-      real, optional, intent(in) :: attention_mask(:, :, :)
+      real, optional, intent(in) :: attention_mask(:, :)
       !! (sequence_length, sequence_length, n_heads)
     end subroutine normalize_attention_matrix
 
@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
     end function get_num_params
 
     module function get_params(self) result(params)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: params(:)
     end function get_params
 
     module function get_gradients(self) result(gradients)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: gradients(:)
     end function get_gradients
 
     module subroutine set_params(self, params)
       class(multihead_attention_layer), intent(in out) :: self
-      real, intent(in), target :: params(:)
+      real, intent(in) :: params(:)
     end subroutine set_params
 
     module subroutine init_base(self, input_shape)
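
Usage sketch (not part of the diff above): a minimal illustration of how a caller might exercise the new optional attention_mask argument of common_forward and common_backward. It assumes an additive mask convention (0.0 for visible positions, a large negative value for masked ones), input arrays shaped (sequence_length, model_dimension), and a layer that has already been constructed and initialized; the helper name and these details are assumptions, not confirmed by the interfaces shown here.

subroutine masked_self_attention_example(attn, x, d_out)
  ! Hypothetical helper for illustration only; not part of the library.
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none
  type(multihead_attention_layer), intent(in out) :: attn  ! assumed already constructed and initialized
  real, intent(in) :: x(:, :)      ! assumed shape (sequence_length, model_dimension)
  real, intent(in) :: d_out(:, :)  ! gradient w.r.t. the layer output, same shape as x
  real, allocatable :: mask(:, :)
  integer :: seq_len, i, j

  seq_len = size(x, 1)
  allocate(mask(seq_len, seq_len))
  mask = 0.0

  ! Causal mask: position i attends only to positions j <= i.
  ! Assumed additive convention: masked scores receive a large negative value.
  do i = 1, seq_len
    do j = i + 1, seq_len
      mask(i, j) = -1.0e9
    end do
  end do

  ! Self-attention: pass the same array as query, key and value.
  call attn % common_forward(x, x, x, attention_mask=mask)

  ! Backward pass with the same mask; for self-attention the input is x.
  call attn % common_backward(x, d_out, attention_mask=mask)
end subroutine masked_self_attention_example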