Skip to content

Commit a9111b3

Browse files
committed
WIP
1 parent 98d12b8 commit a9111b3

6 files changed

+31
-36
lines changed

src/nf/nf_layer.f90

+5-4
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ module nf_layer
3030
procedure :: get_params
3131
procedure :: get_gradients
3232
procedure :: set_params
33+
procedure :: set_state
3334
procedure :: init
3435
procedure :: print_info
35-
procedure :: reset
3636

3737
! Specific subroutines for different array ranks
3838
procedure, private :: backward_1d
@@ -154,9 +154,10 @@ module subroutine set_params(self, params)
154154
!! Parameters of this layer
155155
end subroutine set_params
156156

157-
module subroutine reset(self)
158-
class(layer), intent(in out) :: self
159-
end subroutine reset
157+
module subroutine set_state(self, state)
158+
class(layer), intent(inout) :: self
159+
real, intent(in), optional :: state(:)
160+
end subroutine set_state
160161

161162
end interface
162163

src/nf/nf_layer_constructors_submodule.f90

-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ pure module function input1d(layer_size) result(res)
8181
res % initialized = .true.
8282
end function input1d
8383

84-
8584
pure module function input3d(layer_shape) result(res)
8685
integer, intent(in) :: layer_shape(3)
8786
type(layer) :: res

src/nf/nf_layer_submodule.f90

+10-6
Original file line numberDiff line numberDiff line change
@@ -442,14 +442,18 @@ module subroutine set_params(self, params)
442442

443443
end subroutine set_params
444444

445-
module subroutine reset(self)
446-
class(layer), intent(in out) :: self
445+
subroutine set_state(self, state)
446+
class(layer), intent(inout) :: self
447+
real, intent(in), optional :: state(:)
447448

448449
select type (this_layer => self % p)
449450
type is (rnn_layer)
450-
call this_layer % reset()
451-
end select
452-
453-
end subroutine reset
451+
if (present(state)) then
452+
this_layer % state = state
453+
else
454+
this_layer % state = 0
455+
end if
456+
end select
457+
end subroutine set_state
454458

455459
end submodule nf_layer_submodule

src/nf/nf_network.f90

-8
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ module nf_network
2323
procedure :: get_params
2424
procedure :: print_info
2525
procedure :: set_params
26-
procedure :: reset
2726
procedure :: train
2827
procedure :: update
2928

@@ -224,13 +223,6 @@ module subroutine update(self, optimizer, batch_size)
224223
!! Set to `size(input_data, dim=2)` for a batch gradient descent.
225224
end subroutine update
226225

227-
module subroutine reset(self)
228-
!! Reset network state
229-
!!
230-
!! Currently only affect RNN layer type
231-
class(network), intent(in out) :: self
232-
end subroutine reset
233-
234226
end interface
235227

236228
end module nf_network

src/nf/nf_network_submodule.f90

+3-14
Original file line numberDiff line numberDiff line change
@@ -676,23 +676,12 @@ module subroutine update(self, optimizer, batch_size)
676676
type is(conv2d_layer)
677677
this_layer % dw = 0
678678
this_layer % db = 0
679-
end select
680-
end do
681-
682-
end subroutine update
683-
684-
module subroutine reset(self)
685-
class(network), intent(in out) :: self
686-
integer :: n, num_layers
687-
688-
num_layers = size(self % layers)
689-
do n = 2, num_layers
690-
select type(this_layer => self % layers(n) % p)
691679
type is(rnn_layer)
692-
call self % layers(n) % reset()
680+
this_layer % dw = 0
681+
this_layer % db = 0
693682
end select
694683
end do
695684

696-
end subroutine reset
685+
end subroutine update
697686

698687
end submodule nf_network_submodule

src/nf/nf_rnn_layer.f90

+13-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module nf_rnn_layer
1414

1515
type, extends(base_layer) :: rnn_layer
1616

17-
!! Concrete implementation of a dense (fully-connected) layer type
17+
!! Concrete implementation of an RNN (fully-connected) layer type
1818

1919
integer :: input_size
2020
integer :: output_size
@@ -40,7 +40,7 @@ module nf_rnn_layer
4040
procedure :: get_params
4141
procedure :: init
4242
procedure :: set_params
43-
procedure :: reset
43+
procedure :: set_state
4444

4545
end type rnn_layer
4646

@@ -94,7 +94,7 @@ pure module function get_params(self) result(params)
9494
!! Return the parameters (weights and biases) of this layer.
9595
!! The parameters are ordered as weights first, biases second.
9696
class(rnn_layer), intent(in) :: self
97-
!! Dense layer instance
97+
!! RNN layer instance
9898
real, allocatable :: params(:)
9999
!! Parameters of this layer
100100
end function get_params
@@ -137,4 +137,14 @@ end subroutine reset
137137

138138
end interface
139139

140+
subroutine set_state(self, state)
141+
type(rnn_layer), intent(inout) :: self
142+
real, intent(in), optional :: state(:)
143+
if (present(state)) then
144+
self % state = state
145+
else
146+
self % state = 0
147+
end if
148+
end subroutine set_state
149+
140150
end module nf_rnn_layer

0 commit comments

Comments
 (0)