
Commit ed8b340

OneAdder, milancurcic, and jvdp1 authored
Multihead attention (#199)

* linear2d_layer forward implementation
* linear2d_layer: temporarily remove api
* Don't expose the concrete layer type via nf
* Plumbing of linear2d with input2d and linear2d
* linear2d_layer: add flatten2d layer
* linear2d_layer: make linear2d layer work with input2d and flatten2d
* update cmake
* linear2d_layer: remove flatten2d layer
* linear2d_layer: remove public api
* linear2d_layer: update cmakelists
* Add linear2d example
* linear2d_layer: remove redundant constructor args
* linear2d_layer: make example converge
* linear2d_layer: add loss stopping and more iterations
* start implementing MultiHeadAttention
* scaled dot product attention
* combine attention heads
* forward (not working)
* rearrange attention dimensions in more efficient way
* initial forward implementation for multi-head attention
* tests for multihead_attention%forward
* multihead_attention: move most logic to subroutines (performance)
* multihead_attention: update tests
* multihead_attention: concurrency
* multihead_attention: proof of concept backward (works, but not mathematically correct)
* multihead_attention: fix minor scaling issue
* multihead_attention: complete backward implementation
* multihead_attention: add comments for forward prop
* multihead_attention: add tests for backward
* multihead_attention: adjust expected test values for updated scaling
* multihead_attention: calculate scaling factor only once
* multihead_attention: use heap-allocated arrays during back prop
* multihead_attention: use heap-allocated arrays in forward
* multihead_attention: set values from correct shape to tests
* multihead_attention: fix issues with shapes (softmax prime became even more monstrous)
* multihead_attention: minor refactoring and optimization
* multihead_attention: fix comments
* multihead_attention: tests, add checks for attention weights
* multihead_attention: remove some of the copy-paste comments
* multihead_attention: optimize shapes
* multihead_attention: params api
* multihead_attention: fix incorrect dw bug
* multihead_attention: tests for updated parameters
* multihead_attention: remove reshape crutches
* multihead_attention: rename common forward and backward calls
* multihead_attention: tidy mha up
* multihead_attention: self attention
* multihead_attention: add cross attention
* multihead_attention: add more comments
* multihead_attention: arrange attention into submodule
* multihead_attention: update cmakelists
* multihead_attention: update attention in accordance with linear2d
* multihead_attention: remove redundant constructor args for attention layers
* multihead_attention: use pure and elemental where necessary
* multihead_attention: plumbing
* multihead_attention: add reference
* multihead_attention: remove rebase artifact
* multihead_attention: remove redundant args
* multihead_attention: update tests
* multihead_attention: add the most important lines to tests
* multihead_attention: simple MHA example
* multihead_attention: update cmake
* multihead_attention: remove debug line from tests
* multihead_attention: set slightly higher margin for fp imprecision (due to IEEE_DENORMAL)
* Rename mha_simple example
* Update src/nf/nf_multihead_attention.f90 (four review commits, co-authored by Jeremie Vandenplas <[email protected]>)
* Tidy up
* Add self_attention to the layers table

---------

Co-authored-by: milancurcic <[email protected]>
Co-authored-by: Jeremie Vandenplas <[email protected]>
1 parent 039638d commit ed8b340

15 files changed: +1188 −5 lines
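The commit titles above mention scaled dot-product attention. For orientation only (this note is not part of the diff), the standard multi-head formulation of Vaswani et al. (2017), which the new layers appear to follow with a per-head scaling of $1/\sqrt{d_k}$, is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\, K W_i^K,\, V W_i^V)$$

The `self_attention` layer takes Q, K, and V from the same input sequence; the cross-attention layer takes Q from one sequence and K and V from another (see src/nf/nf_cross_attention_layer.f90 below).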

CMakeLists.txt (+4)

@@ -20,6 +20,7 @@ add_library(neural-fortran
   src/nf/nf_base_layer.f90
   src/nf/nf_conv2d_layer.f90
   src/nf/nf_conv2d_layer_submodule.f90
+  src/nf/nf_cross_attention_layer.f90
   src/nf/nf_datasets.f90
   src/nf/nf_datasets_submodule.f90
   src/nf/nf_datasets_mnist.f90
@@ -45,6 +46,8 @@ add_library(neural-fortran
   src/nf/nf_maxpool2d_layer.f90
   src/nf/nf_maxpool2d_layer_submodule.f90
   src/nf/nf_metrics.f90
+  src/nf/nf_multihead_attention.f90
+  src/nf/nf_multihead_attention_submodule.f90
   src/nf/nf_network.f90
   src/nf/nf_network_submodule.f90
   src/nf/nf_optimizers.f90
@@ -53,6 +56,7 @@ add_library(neural-fortran
   src/nf/nf_random.f90
   src/nf/nf_reshape_layer.f90
   src/nf/nf_reshape_layer_submodule.f90
+  src/nf/nf_self_attention_layer.f90
   src/nf/io/nf_io_binary.f90
   src/nf/io/nf_io_binary_submodule.f90
   src/nf/nf_dropout_layer.f90

README.md (+2 −1)

@@ -34,8 +34,9 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ |
 | Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅(*) |
 | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
+| Linear (2-d) | `linear2d` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
+| Self-attention | `self_attention` | `input2d`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
 | Flatten | `flatten` | `input2d`, `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✅ | ✅ |
-| Linear (2-d) | `linear2d` | `input2d`, `linear2d` | 2 | ✅ | ✅ |
 | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✅ | ✅ |

 (*) See Issue [#145](https://github.com/modern-fortran/neural-fortran/issues/145) regarding non-converging CNN training on the MNIST dataset.

example/CMakeLists.txt (+1)

@@ -6,6 +6,7 @@ foreach(execid
   simple
   sine
   quadratic
+  mha_simple
 )
 add_executable(${execid} ${execid}.f90)
 target_link_libraries(${execid} PRIVATE

example/mha_simple.f90 (+37, new file)

@@ -0,0 +1,37 @@
+program mha_simple
+  use nf, only: dense, input, network, sgd, self_attention, flatten
+  implicit none
+  type(network) :: net
+  real, allocatable :: x(:, :), y(:)
+  integer, parameter :: num_iterations = 500
+  integer :: n
+
+  print '("Simple")'
+  print '(60("="))'
+
+  net = network([ &
+    input(3, 8), &
+    self_attention(4), &
+    flatten(), &
+    dense(2) &
+  ])
+
+  call net % print_info()
+
+  allocate(x(3, 8))
+  call random_number(x)
+
+  y = [0.123456, 0.246802]
+
+  do n = 0, num_iterations
+
+    call net % forward(x)
+    call net % backward(y)
+    call net % update(optimizer=sgd(learning_rate=1.))
+
+    if (mod(n, 50) == 0) &
+      print '(i4,2(3x,f8.6))', n, net % predict(x)
+
+  end do
+
+end program mha_simple

src/nf.f90 (+11 −1)

@@ -3,7 +3,15 @@ module nf
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
-    conv2d, dense, dropout, flatten, input, linear2d, maxpool2d, reshape
+    conv2d, &
+    dense, &
+    dropout, &
+    flatten, &
+    input, &
+    linear2d, &
+    maxpool2d, &
+    reshape, &
+    self_attention
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
@@ -12,4 +20,6 @@ module nf
     gaussian, linear, relu, leaky_relu, &
     sigmoid, softmax, softplus, step, tanhf, &
     celu
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
 end module nf

src/nf/nf_cross_attention_layer.f90 (+66, new file)

@@ -0,0 +1,66 @@
+module nf_cross_attention_layer
+  use iso_fortran_env, only: stderr => error_unit
+  use nf_activation, only: softmax
+  use nf_linear2d_layer, only: linear2d_layer
+  use nf_multihead_attention_layer, only: multihead_attention_layer
+
+  implicit none
+
+  type, extends(multihead_attention_layer) :: cross_attention_layer
+    !! Cross Attention Layer
+    !! Source:
+    !! Bahdanau, D. (2014)
+    !! Neural machine translation by jointly learning to align and translate.
+    !! https://arxiv.org/pdf/1409.0473
+    real, allocatable :: gradient(:, :, :)
+  contains
+    procedure :: forward
+    procedure :: backward
+    procedure :: init
+  end type cross_attention_layer
+
+  interface cross_attention_layer
+    module function cross_attention_layer_cons(n_heads) result(res)
+      !! This function returns the `cross_attention_layer` instance.
+      integer, intent(in) :: n_heads
+      type(cross_attention_layer) :: res
+    end function cross_attention_layer_cons
+  end interface cross_attention_layer
+
+contains
+  module function cross_attention_layer_cons(n_heads) result(res)
+    !! This function returns the `cross_attention_layer` instance.
+    integer, intent(in) :: n_heads
+    type(cross_attention_layer) :: res
+    res % n_heads = n_heads
+  end function cross_attention_layer_cons
+
+  pure module subroutine backward(self, input, gradient)
+    !! Cross attention backward propagation
+    class(cross_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+    real, intent(in) :: gradient(:, :)
+
+    call self % common_backward(input(1, :, :), gradient)
+    self % gradient(1, :, :) = self % query_layer % gradient
+    self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
+  end subroutine backward
+
+  pure module subroutine forward(self, input)
+    !! Cross attention forward propagation.
+    !! Input shape: (kind, sequence_length, model_dimension),
+    !! where kind is 1 for query and 2 for key-value.
+    class(cross_attention_layer), intent(in out) :: self
+    real, intent(in) :: input(:, :, :)
+
+    call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
+  end subroutine forward
+
+  module subroutine init(self, input_shape)
+    class(cross_attention_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    call self % init_base(input_shape)
+    allocate(self % gradient(2, self % sequence_length, self % model_dimension))
+  end subroutine init
+end module nf_cross_attention_layer
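Below is a minimal, hypothetical usage sketch for the cross-attention layer above (not part of this commit). It assumes that init accepts [sequence_length, model_dimension] and that the attended result is stored in the output component inherited from multihead_attention_layer, as the self_attention plumbing in src/nf/nf_layer_submodule.f90 suggests.

program cross_attention_sketch
  ! Hypothetical sketch; the assumed model_dimension of 8 must be divisible by n_heads.
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none
  type(cross_attention_layer) :: attn
  ! Shape (kind, sequence_length, model_dimension); kind 1 holds the query
  ! sequence, kind 2 the key-value sequence.
  real :: x(2, 3, 8)

  attn = cross_attention_layer(n_heads=4)
  call attn % init([3, 8])        ! assumption: init takes [sequence_length, model_dimension]

  call random_number(x)
  call attn % forward(x)          ! attends x(1,:,:) over x(2,:,:)
  print *, shape(attn % output)   ! assumption: inherited `output` holds the result
end program cross_attention_sketch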

src/nf/nf_layer_constructors.f90 (+20 −1)

@@ -8,7 +8,16 @@ module nf_layer_constructors
   implicit none

   private
-  public :: conv2d, dense, dropout, flatten, input, linear2d, maxpool2d, reshape
+  public :: &
+    conv2d, &
+    dense, &
+    dropout, &
+    flatten, &
+    input, &
+    linear2d, &
+    maxpool2d, &
+    reshape, &
+    self_attention

   interface input

@@ -213,6 +222,16 @@ module function linear2d(out_features) result(res)
         !! Resulting layer instance
     end function linear2d

+    module function self_attention(num_heads) result(res)
+      !! Rank-2 (sequence_length, out_features) self attention constructor.
+      !! sequence_length and model_dimension are determined at layer initialization, based on the
+      !! output shape of the previous layer.
+      integer, intent(in) :: num_heads
+        !! Number of attention heads
+      type(layer) :: res
+        !! Resulting layer instance
+    end function self_attention
+
   end interface

 end module nf_layer_constructors

src/nf/nf_layer_constructors_submodule.f90 (+9)

@@ -11,6 +11,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_activation, only: activation_function, relu, sigmoid

   implicit none
@@ -170,4 +171,12 @@ module function linear2d(out_features) result(res)

   end function linear2d

+  module function self_attention(num_heads) result(res)
+    integer, intent(in) :: num_heads
+    type(layer) :: res
+
+    res % name = 'self_attention'
+    allocate(res % p, source=self_attention_layer(num_heads))
+  end function self_attention
+
 end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90 (+45 −2)

@@ -11,6 +11,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_optimizers, only: optimizer_base_type

 contains
@@ -57,6 +58,8 @@ pure module subroutine backward_1d(self, previous, gradient)
           call this_layer % backward(prev_layer % output, gradient)
         type is(linear2d_layer)
           call this_layer % backward(prev_layer % output, gradient)
+        type is(self_attention_layer)
+          call this_layer % backward(prev_layer % output, gradient)
       end select

   end select
@@ -79,6 +82,19 @@ pure module subroutine backward_2d(self, previous, gradient)
           call this_layer % backward(prev_layer % output, gradient)
         type is(linear2d_layer)
           call this_layer % backward(prev_layer % output, gradient)
+        type is(self_attention_layer)
+          call this_layer % backward(prev_layer % output, gradient)
+      end select
+
+    type is(self_attention_layer)
+
+      select type(prev_layer => previous % p)
+        type is(input2d_layer)
+          call this_layer % backward(prev_layer % output, gradient)
+        type is(linear2d_layer)
+          call this_layer % backward(prev_layer % output, gradient)
+        type is(self_attention_layer)
+          call this_layer % backward(prev_layer % output, gradient)
       end select

   end select
@@ -240,6 +256,20 @@ module subroutine forward(self, input)
           call this_layer % forward(prev_layer % output)
         type is(linear2d_layer)
           call this_layer % forward(prev_layer % output)
+        type is(self_attention_layer)
+          call this_layer % forward(prev_layer % output)
+      end select
+
+    type is(self_attention_layer)
+
+      ! Upstream layers permitted: input2d, linear2d, self_attention
+      select type(prev_layer => input % p)
+        type is(input2d_layer)
+          call this_layer % forward(prev_layer % output)
+        type is(linear2d_layer)
+          call this_layer % forward(prev_layer % output)
+        type is(self_attention_layer)
+          call this_layer % forward(prev_layer % output)
       end select

   end select
@@ -279,6 +309,8 @@ pure module subroutine get_output_2d(self, output)
       allocate(output, source=this_layer % output)
     type is(linear2d_layer)
       allocate(output, source=this_layer % output)
+    type is(self_attention_layer)
+      allocate(output, source=this_layer % output)
     class default
       error stop '2-d output can only be read from an input2d or linear2d layer.'

@@ -322,8 +354,8 @@ impure elemental module subroutine init(self, input)
       call this_layer % init(input % layer_shape)
   end select

-  ! The shape of conv2d, dropout, flatten, linear2d, or maxpool2d layers
-  ! is not known until we receive an input layer.
+  ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
+  ! self_attention layers is not known until we receive an input layer.
   select type(this_layer => self % p)
     type is(conv2d_layer)
       self % layer_shape = shape(this_layer % output)
@@ -333,6 +365,8 @@ impure elemental module subroutine init(self, input)
       self % layer_shape = shape(this_layer % output)
     type is(linear2d_layer)
       self % layer_shape = shape(this_layer % output)
+    type is(self_attention_layer)
+      self % layer_shape = shape(this_layer % output)
     type is(maxpool2d_layer)
       self % layer_shape = shape(this_layer % output)
   end select
@@ -389,6 +423,8 @@ elemental module function get_num_params(self) result(num_params)
       num_params = 0
     type is (linear2d_layer)
       num_params = this_layer % get_num_params()
+    type is (self_attention_layer)
+      num_params = this_layer % get_num_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -420,6 +456,8 @@ module function get_params(self) result(params)
       ! No parameters to get.
     type is (linear2d_layer)
       params = this_layer % get_params()
+    type is (self_attention_layer)
+      params = this_layer % get_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -451,6 +489,8 @@ module function get_gradients(self) result(gradients)
       ! No gradients to get.
     type is (linear2d_layer)
       gradients = this_layer % get_gradients()
+    type is (self_attention_layer)
+      gradients = this_layer % get_gradients()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -506,6 +546,9 @@ module subroutine set_params(self, params)
     type is (linear2d_layer)
       call this_layer % set_params(params)

+    type is (self_attention_layer)
+      call this_layer % set_params(params)
+
     type is (maxpool2d_layer)
       ! No parameters to set.
       write(stderr, '(a)') 'Warning: calling set_params() ' &
