1
1
program test_embedding_layer
2
2
use iso_fortran_env, only: stderr = > error_unit
3
3
use nf_embedding_layer, only: embedding_layer
4
+ use nf_layer, only: layer
5
+ use nf_layer_constructors, only: embedding_constructor = > embedding
4
6
implicit none
5
7
6
8
logical :: ok = .true.
9
+ integer :: sample_input(3 ) = [2 , 1 , 3 ]
7
10
8
- call test_simple(ok)
9
- call test_positional_trigonometric(ok)
10
- call test_positional_absolute(ok)
11
+ call test_simple(ok, sample_input )
12
+ call test_positional_trigonometric(ok, sample_input )
13
+ call test_positional_absolute(ok, sample_input )
11
14
12
15
if (ok) then
13
16
print ' (a)' , ' test_embedding_layer: All tests passed.'
@@ -17,10 +20,10 @@ program test_embedding_layer
17
20
end if
18
21
19
22
contains
20
- subroutine test_simple (ok )
23
+ subroutine test_simple (ok , sample_input )
21
24
logical , intent (in out ) :: ok
25
+ integer , intent (in ) :: sample_input(:)
22
26
23
- integer :: sample_input(3 ) = [2 , 1 , 3 ]
24
27
real :: sample_gradient(3 , 2 ) = reshape ([0.1 , 0.2 , 0.3 , 0.4 , 0.6 , 0.6 ], [3 , 2 ])
25
28
real :: output_flat(6 )
26
29
real :: expected_output_flat(6 ) = reshape ([0.3 , 0.1 , 0.5 , 0.4 , 0.2 , 0.6 ], [6 ])
@@ -48,10 +51,10 @@ subroutine test_simple(ok)
48
51
end if
49
52
end subroutine test_simple
50
53
51
- subroutine test_positional_trigonometric (ok )
54
+ subroutine test_positional_trigonometric (ok , sample_input )
52
55
logical , intent (in out ) :: ok
56
+ integer , intent (in ) :: sample_input(:)
53
57
54
- integer :: sample_input(3 ) = [2 , 1 , 3 ]
55
58
real :: output_flat(12 )
56
59
real :: expected_output_flat(12 ) = reshape ([&
57
60
0.3 , 0.941471 , 1.4092975 ,&
@@ -82,10 +85,10 @@ subroutine test_positional_trigonometric(ok)
82
85
end if
83
86
end subroutine test_positional_trigonometric
84
87
85
- subroutine test_positional_absolute (ok )
88
+ subroutine test_positional_absolute (ok , sample_input )
86
89
logical , intent (in out ) :: ok
90
+ integer , intent (in ) :: sample_input(:)
87
91
88
- integer :: sample_input(3 ) = [2 , 1 , 3 ]
89
92
real :: output_flat(12 )
90
93
real :: expected_output_flat(12 ) = reshape ([&
91
94
0.3 , 1.1 , 2.5 ,&
@@ -115,4 +118,16 @@ subroutine test_positional_absolute(ok)
115
118
write (stderr, ' (a)' ) ' absolute positional encoding returned incorrect values.. failed'
116
119
end if
117
120
end subroutine test_positional_absolute
121
+
122
+ subroutine test_embedding_constructor (ok , sample_input )
123
+ logical , intent (in out ) :: ok
124
+ integer , intent (in ) :: sample_input(:)
125
+
126
+ type (layer) :: embedding_constructed
127
+
128
+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 )
129
+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 0 )
130
+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 1 )
131
+ embedding_constructed = embedding_constructor(sequence_length= 3 , vocab_size= 5 , model_dimension= 4 , positional= 2 )
132
+ end subroutine test_embedding_constructor
118
133
end program test_embedding_layer
0 commit comments