-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprojectAndReshapeLayer.m
81 lines (65 loc) · 2.41 KB
/
projectAndReshapeLayer.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
%% Synthetic Biomedical Image Generation with GANs
% Olivier Rukundo, Ph.D. - [email protected] - 2024-04-09
% This script is a modification of the example provided by MathWorks on training generative adversarial networks.
% Original source: https://se.mathworks.com/help/deeplearning/ug/train-generative-adversarial-network.html
classdef projectAndReshapeLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable
    % projectAndReshapeLayer   Fully connected projection followed by a reshape.
    %
    %   The layer applies a fully connect operation (Weights, Bias) to its
    %   input and reshapes the result to
    %   OutputSize(1)-by-OutputSize(2)-by-OutputSize(3)-by-N, which is then
    %   labelled with the format "SSCB".

    properties
        % Target size of the reshaped output. predict reads its first
        % three elements.
        OutputSize
    end

    properties (Learnable)
        % Projection matrix, created in initialize with size
        % [prod(OutputSize) numChannels] when left empty.
        Weights

        % Projection offset, created in initialize with size
        % [prod(OutputSize) 1] when left empty.
        Bias
    end

    methods
        function layer = projectAndReshapeLayer(outputSize,NameValueArgs)
            % layer = projectAndReshapeLayer(outputSize) creates a layer
            % that projects its input and reshapes it to outputSize.
            %
            % layer = projectAndReshapeLayer(outputSize,Name=name) also
            % sets the layer name. The default name is "".

            arguments
                outputSize
                NameValueArgs.Name = "";
            end

            % Store the requested output size.
            layer.OutputSize = outputSize;

            % Name, type and human-readable description of the layer.
            layer.Name = NameValueArgs.Name;
            layer.Type = "Project and Reshape";
            layer.Description = "Project and reshape to size " + join(string(outputSize));
        end

        function layer = initialize(layer,layout)
            % layer = initialize(layer,layout) fills in the learnable
            % parameters that are still empty. Parameters that already
            % hold a value are left untouched.
            %
            % layout is queried with finddim(layout,"C") and layout.Size
            % to obtain the number of input channels.
            % NOTE(review): initializeGlorot and initializeZeros are not
            % defined in this file; they must be available on the path.

            % Number of elements produced by the projection.
            numProjected = prod(layer.OutputSize);

            % Projection weights: Glorot initialization.
            if isempty(layer.Weights)
                % Size of the channel dimension of the input layout.
                channelDim = finddim(layout,"C");
                numInputs = layout.Size(channelDim);

                layer.Weights = initializeGlorot( ...
                    [numProjected numInputs],numProjected,numInputs);
            end

            % Projection bias: zeros.
            if isempty(layer.Bias)
                layer.Bias = initializeZeros([numProjected 1]);
            end
        end

        function Z = predict(layer, X)
            % Z = predict(layer,X) projects X with the layer weights and
            % bias and returns the result reshaped to the first three
            % entries of OutputSize, labelled "SSCB".

            % Fully connect.
            projected = fullyconnect(X,layer.Weights,layer.Bias);

            % Reshape; the trailing [] lets the fourth dimension absorb
            % the remaining elements. Then attach the "SSCB" labels.
            targetSize = layer.OutputSize;
            Z = dlarray( ...
                reshape(projected,targetSize(1),targetSize(2),targetSize(3),[]), ...
                "SSCB");
        end
    end
end