11
11
use nf_maxpool2d_layer, only: maxpool2d_layer
12
12
use nf_reshape_layer, only: reshape3d_layer
13
13
use nf_linear2d_layer, only: linear2d_layer
14
+ use nf_self_attention_layer, only: self_attention_layer
14
15
use nf_optimizers, only: optimizer_base_type
15
16
16
17
contains
@@ -57,6 +58,8 @@ pure module subroutine backward_1d(self, previous, gradient)
57
58
call this_layer % backward(prev_layer % output, gradient)
58
59
type is (linear2d_layer)
59
60
call this_layer % backward(prev_layer % output, gradient)
61
+ type is (self_attention_layer)
62
+ call this_layer % backward(prev_layer % output, gradient)
60
63
end select
61
64
62
65
end select
@@ -79,6 +82,19 @@ pure module subroutine backward_2d(self, previous, gradient)
79
82
call this_layer % backward(prev_layer % output, gradient)
80
83
type is (linear2d_layer)
81
84
call this_layer % backward(prev_layer % output, gradient)
85
+ type is (self_attention_layer)
86
+ call this_layer % backward(prev_layer % output, gradient)
87
+ end select
88
+
89
+ type is (self_attention_layer)
90
+
91
+ select type (prev_layer = > previous % p)
92
+ type is (input2d_layer)
93
+ call this_layer % backward(prev_layer % output, gradient)
94
+ type is (linear2d_layer)
95
+ call this_layer % backward(prev_layer % output, gradient)
96
+ type is (self_attention_layer)
97
+ call this_layer % backward(prev_layer % output, gradient)
82
98
end select
83
99
84
100
end select
@@ -240,6 +256,20 @@ module subroutine forward(self, input)
240
256
call this_layer % forward(prev_layer % output)
241
257
type is (linear2d_layer)
242
258
call this_layer % forward(prev_layer % output)
259
+ type is (self_attention_layer)
260
+ call this_layer % forward(prev_layer % output)
261
+ end select
262
+
263
+ type is (self_attention_layer)
264
+
265
+ ! Upstream layers permitted: input2d, linear2d, self_attention
266
+ select type (prev_layer = > input % p)
267
+ type is (input2d_layer)
268
+ call this_layer % forward(prev_layer % output)
269
+ type is (linear2d_layer)
270
+ call this_layer % forward(prev_layer % output)
271
+ type is (self_attention_layer)
272
+ call this_layer % forward(prev_layer % output)
243
273
end select
244
274
245
275
end select
@@ -279,6 +309,8 @@ pure module subroutine get_output_2d(self, output)
279
309
allocate (output, source= this_layer % output)
280
310
type is (linear2d_layer)
281
311
allocate (output, source= this_layer % output)
312
+ type is (self_attention_layer)
313
+ allocate (output, source= this_layer % output)
282
314
class default
283
315
error stop ' 2-d output can only be read from an input2d or linear2d layer.'
284
316
@@ -322,8 +354,8 @@ impure elemental module subroutine init(self, input)
322
354
call this_layer % init(input % layer_shape)
323
355
end select
324
356
325
- ! The shape of conv2d, dropout, flatten, linear2d, or maxpool2d layers
326
- ! is not known until we receive an input layer.
357
+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
358
+ ! self_attention layers is not known until we receive an input layer.
327
359
select type (this_layer = > self % p)
328
360
type is (conv2d_layer)
329
361
self % layer_shape = shape (this_layer % output)
@@ -333,6 +365,8 @@ impure elemental module subroutine init(self, input)
333
365
self % layer_shape = shape (this_layer % output)
334
366
type is (linear2d_layer)
335
367
self % layer_shape = shape (this_layer % output)
368
+ type is (self_attention_layer)
369
+ self % layer_shape = shape (this_layer % output)
336
370
type is (maxpool2d_layer)
337
371
self % layer_shape = shape (this_layer % output)
338
372
end select
@@ -389,6 +423,8 @@ elemental module function get_num_params(self) result(num_params)
389
423
num_params = 0
390
424
type is (linear2d_layer)
391
425
num_params = this_layer % get_num_params()
426
+ type is (self_attention_layer)
427
+ num_params = this_layer % get_num_params()
392
428
class default
393
429
error stop ' Unknown layer type.'
394
430
end select
@@ -420,6 +456,8 @@ module function get_params(self) result(params)
420
456
! No parameters to get.
421
457
type is (linear2d_layer)
422
458
params = this_layer % get_params()
459
+ type is (self_attention_layer)
460
+ params = this_layer % get_params()
423
461
class default
424
462
error stop ' Unknown layer type.'
425
463
end select
@@ -451,6 +489,8 @@ module function get_gradients(self) result(gradients)
451
489
! No gradients to get.
452
490
type is (linear2d_layer)
453
491
gradients = this_layer % get_gradients()
492
+ type is (self_attention_layer)
493
+ gradients = this_layer % get_gradients()
454
494
class default
455
495
error stop ' Unknown layer type.'
456
496
end select
@@ -506,6 +546,9 @@ module subroutine set_params(self, params)
506
546
type is (linear2d_layer)
507
547
call this_layer % set_params(params)
508
548
549
+ type is (self_attention_layer)
550
+ call this_layer % set_params(params)
551
+
509
552
type is (maxpool2d_layer)
510
553
! No parameters to set.
511
554
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments