Skip to content

Commit 90d80d2

Browse files
committed
embedding_layer: update constructor and tests
1 parent d6bbbac commit 90d80d2

File tree

3 files changed

+29
-12
lines changed

3 files changed

+29
-12
lines changed

src/nf/nf_layer_constructors.f90

+2-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ module function linear2d(out_features) result(res)
195195
!! Resulting layer instance
196196
end function linear2d
197197

198-
module function embedding(sequence_length, vocab_size, model_dimension) result(res)
198+
module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
199199
!! Embedding layer constructor.
200200
!!
201201
!! This layer is for inputting token indices from the dictionary to the network.
@@ -205,6 +205,7 @@ module function embedding(sequence_length, vocab_size, model_dimension) result(r
205205
!! `vocab_size`: length of token vocabulary
206206
!! `model_dimension`: size of target embeddings
207207
integer, intent(in) :: sequence_length, vocab_size, model_dimension
208+
integer, optional, intent(in) :: positional
208209
type(layer) :: res
209210
end function embedding
210211

src/nf/nf_layer_constructors_submodule.f90

+3-2
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,13 @@ module function linear2d(out_features) result(res)
162162
end function linear2d
163163

164164

165-
module function embedding(sequence_length, vocab_size, model_dimension) result(res)
165+
module function embedding(sequence_length, vocab_size, model_dimension, positional) result(res)
166166
integer, intent(in) :: sequence_length, vocab_size, model_dimension
167+
integer, optional, intent(in) :: positional
167168
type(layer) :: res
168169
type(embedding_layer) :: embedding_layer_instance
169170

170-
embedding_layer_instance = embedding_layer(vocab_size, model_dimension)
171+
embedding_layer_instance = embedding_layer(vocab_size, model_dimension, positional)
171172
call embedding_layer_instance % init([sequence_length])
172173
res % name = 'embedding'
173174
res % layer_shape = [sequence_length, model_dimension]

test/test_embedding_layer.f90

+24-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
program test_embedding_layer
22
use iso_fortran_env, only: stderr => error_unit
33
use nf_embedding_layer, only: embedding_layer
4+
use nf_layer, only: layer
5+
use nf_layer_constructors, only: embedding_constructor => embedding
46
implicit none
57

68
logical :: ok = .true.
9+
integer :: sample_input(3) = [2, 1, 3]
710

8-
call test_simple(ok)
9-
call test_positional_trigonometric(ok)
10-
call test_positional_absolute(ok)
11+
call test_simple(ok, sample_input)
12+
call test_positional_trigonometric(ok, sample_input)
13+
call test_positional_absolute(ok, sample_input)
1114

1215
if (ok) then
1316
print '(a)', 'test_embedding_layer: All tests passed.'
@@ -17,10 +20,10 @@ program test_embedding_layer
1720
end if
1821

1922
contains
20-
subroutine test_simple(ok)
23+
subroutine test_simple(ok, sample_input)
2124
logical, intent(in out) :: ok
25+
integer, intent(in) :: sample_input(:)
2226

23-
integer :: sample_input(3) = [2, 1, 3]
2427
real :: sample_gradient(3, 2) = reshape([0.1, 0.2, 0.3, 0.4, 0.6, 0.6], [3, 2])
2528
real :: output_flat(6)
2629
real :: expected_output_flat(6) = reshape([0.3, 0.1, 0.5, 0.4, 0.2, 0.6], [6])
@@ -48,10 +51,10 @@ subroutine test_simple(ok)
4851
end if
4952
end subroutine test_simple
5053

51-
subroutine test_positional_trigonometric(ok)
54+
subroutine test_positional_trigonometric(ok, sample_input)
5255
logical, intent(in out) :: ok
56+
integer, intent(in) :: sample_input(:)
5357

54-
integer :: sample_input(3) = [2, 1, 3]
5558
real :: output_flat(12)
5659
real :: expected_output_flat(12) = reshape([&
5760
0.3, 0.941471, 1.4092975,&
@@ -82,10 +85,10 @@ subroutine test_positional_trigonometric(ok)
8285
end if
8386
end subroutine test_positional_trigonometric
8487

85-
subroutine test_positional_absolute(ok)
88+
subroutine test_positional_absolute(ok, sample_input)
8689
logical, intent(in out) :: ok
90+
integer, intent(in) :: sample_input(:)
8791

88-
integer :: sample_input(3) = [2, 1, 3]
8992
real :: output_flat(12)
9093
real :: expected_output_flat(12) = reshape([&
9194
0.3, 1.1, 2.5,&
@@ -115,4 +118,16 @@ subroutine test_positional_absolute(ok)
115118
write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
116119
end if
117120
end subroutine test_positional_absolute
121+
122+
subroutine test_embedding_constructor(ok, sample_input)
123+
logical, intent(in out) :: ok
124+
integer, intent(in) :: sample_input(:)
125+
126+
type(layer) :: embedding_constructed
127+
128+
embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4)
129+
embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=0)
130+
embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=1)
131+
embedding_constructed = embedding_constructor(sequence_length=3, vocab_size=5, model_dimension=4, positional=2)
132+
end subroutine test_embedding_constructor
118133
end program test_embedding_layer

0 commit comments

Comments
 (0)