diff --git a/src/layers/invertible_layer_hyperbolic.jl b/src/layers/invertible_layer_hyperbolic.jl index bc4e8fe9..80a4ffb6 100644 --- a/src/layers/invertible_layer_hyperbolic.jl +++ b/src/layers/invertible_layer_hyperbolic.jl @@ -53,10 +53,13 @@ Create an invertible hyperbolic coupling layer. See also: [`get_params`](@ref), [`clear_grad!`](@ref) """ -struct HyperbolicLayer{S, P, A} <: NeuralNetLayer +struct HyperbolicLayer <: NeuralNetLayer W::Parameter b::Parameter α::Float32 + stride::Int + pad::Int + action end @Flux.functor HyperbolicLayer @@ -75,7 +78,7 @@ function HyperbolicLayer(n_in::Int64, kernel::Int64, stride::Int64, W = Parameter(glorot_uniform(k..., n_out, n_hidden)) b = Parameter(zeros(Float32, n_hidden)) - return HyperbolicLayer{stride, pad, action}(W, b, α) + return HyperbolicLayer(W, b, α, stride, pad, action) end HyperbolicLayer3D(args...; kw...) = HyperbolicLayer(args...; kw..., ndims=3) @@ -91,7 +94,7 @@ function HyperbolicLayer(W::AbstractArray{Float32, N}, b::AbstractArray{Float32, W = Parameter(W) b = Parameter(b) - return HyperbolicLayer{stride, pad, action}(W, b, α) + return HyperbolicLayer(W, b, α, stride, pad, action) end HyperbolicLayer3D(W::AbstractArray{Float32, N}, b, stride, pad; @@ -101,16 +104,16 @@ HyperbolicLayer3D(W::AbstractArray{Float32, N}, b, stride, pad; ################################################# # Forward pass -function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p, a} +function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer) # Change dimensions - if a == 0 + if HL.action == 0 X_prev = identity(X_prev_in) X_curr = identity(X_curr_in) - elseif a == 1 + elseif HL.action == 1 X_prev = wavelet_unsqueeze(X_prev_in) X_curr = wavelet_unsqueeze(X_curr_in) - elseif a == -1 + elseif HL.action == -1 X_prev = wavelet_squeeze(X_prev_in) X_curr = wavelet_squeeze(X_curr_in) else @@ -118,7 +121,7 @@ function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p end # Symmetric 
convolution w/ relu activation - cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p) + cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad) if length(size(X_curr)) == 4 X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1) else @@ -134,8 +137,8 @@ function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p end # Inverse pass -function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where {s, p, a} - cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p) +function inverse(X_curr, X_new, HL::HyperbolicLayer; save=false) + cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad) # Symmetric convolution w/ relu activation if length(size(X_curr)) == 4 X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1) @@ -149,13 +152,13 @@ function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where X_prev = 2*X_curr - X_new + HL.α*X_convT # Change dimensions - if a == 0 + if HL.action == 0 X_prev_in = identity(X_prev) X_curr_in = identity(X_curr) - elseif a == -1 + elseif HL.action == -1 X_prev_in = wavelet_unsqueeze(X_prev) X_curr_in = wavelet_unsqueeze(X_curr) - elseif a == 1 + elseif HL.action == 1 X_prev_in = wavelet_squeeze(X_prev) X_curr_in = wavelet_squeeze(X_curr) else @@ -170,13 +173,13 @@ function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where end # Backward pass -function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer{s, p, a}; set_grad::Bool=true) where {s, p, a} +function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer; set_grad::Bool=true) # Recompute forward states X_prev_in, X_curr_in, X_conv, X_relu = inverse(X_curr, X_new, HL; save=true) # Backpropagate data residual and compute gradients - cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p) + cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad) ΔX_convT = copy(ΔX_new) ΔX_relu = -HL.α*conv(ΔX_convT, HL.W.data, cdims) ΔW = 
-HL.α*∇conv_filter(ΔX_convT, X_relu, cdims) @@ -201,13 +204,13 @@ function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer{s, p, a} end # Change dimensions - if a == 0 + if HL.action == 0 ΔX_prev_in = identity(ΔX_prev) ΔX_curr_in = identity(ΔX_curr) - elseif a == -1 + elseif HL.action == -1 ΔX_prev_in = wavelet_unsqueeze(ΔX_prev) ΔX_curr_in = wavelet_unsqueeze(ΔX_curr) - elseif a == 1 + elseif HL.action == 1 ΔX_prev_in = wavelet_squeeze(ΔX_prev) ΔX_curr_in = wavelet_squeeze(ΔX_curr) else @@ -221,20 +224,20 @@ end ## Jacobian utilities # 2D -function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p, a} +function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::HyperbolicLayer) # Change dimensions - if a == 0 + if HL.action == 0 X_prev = identity(X_prev_in) X_curr = identity(X_curr_in) ΔX_prev = identity(ΔX_prev_in) ΔX_curr = identity(ΔX_curr_in) - elseif a == 1 + elseif HL.action == 1 X_prev = wavelet_unsqueeze(X_prev_in) X_curr = wavelet_unsqueeze(X_curr_in) ΔX_prev = wavelet_unsqueeze(ΔX_prev_in) ΔX_curr = wavelet_unsqueeze(ΔX_curr_in) - elseif a == -1 + elseif HL.action == -1 X_prev = wavelet_squeeze(X_prev_in) X_curr = wavelet_squeeze(X_curr_in) ΔX_prev = wavelet_squeeze(ΔX_prev_in) @@ -243,7 +246,7 @@ function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::Hype throw("Specified operation not defined.") end - cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p) + cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad) # Symmetric convolution w/ relu activation if length(size(X_curr)) == 4 X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1) diff --git a/src/layers/layer_residual_block.jl b/src/layers/layer_residual_block.jl index cf8c0b46..0cc2c630 100644 --- a/src/layers/layer_residual_block.jl +++ b/src/layers/layer_residual_block.jl @@ -58,13 +58,15 @@ or See also: [`get_params`](@ref), [`clear_grad!`](@ref) """ -struct 
ResidualBlock{S1, S2, P1, P2} <: NeuralNetLayer +struct ResidualBlock <: NeuralNetLayer W1::Parameter W2::Parameter W3::Parameter b1::Parameter b2::Parameter fan::Bool + strides + pad end @Flux.functor ResidualBlock @@ -84,7 +86,7 @@ function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=f b1 = Parameter(zeros(Float32, n_hidden)) b2 = Parameter(zeros(Float32, n_hidden)) - return ResidualBlock{s1, s2, p1, p2}(W1, W2, W3, b1, b2, fan) + return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2)) end # Constructor for given weights @@ -97,7 +99,7 @@ function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, nd b1 = Parameter(b1) b2 = Parameter(b2) - return ResidualBlock{s1, s2, p1, p2}(W1, W2, W3, b1, b2, fan) + return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2)) end ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3) @@ -105,16 +107,16 @@ ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3) # Functions # Forward -function forward(X1::AbstractArray{Float32, N}, RB::ResidualBlock{S1,S2,P1,P2}; save=false) where {S1,S2,P1,P2,N} +function forward(X1::AbstractArray{Float32, N}, RB::ResidualBlock; save=false) where {N} inds =[i!=(N-1) ? 1 : (:) for i=1:N] - Y1 = conv(X1, RB.W1.data; stride=S1, pad=P1) .+ reshape(RB.b1.data, inds...) + Y1 = conv(X1, RB.W1.data; stride=RB.strides[1], pad=RB.pad[1]) .+ reshape(RB.b1.data, inds...) X2 = ReLU(Y1) - Y2 = X2 + conv(X2, RB.W2.data; stride=S2, pad=P2) .+ reshape(RB.b2.data, inds...) + Y2 = X2 + conv(X2, RB.W2.data; stride=RB.strides[2], pad=RB.pad[2]) .+ reshape(RB.b2.data, inds...) X3 = ReLU(Y2) - cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1) + cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1]) Y3 = ∇conv_data(X3, RB.W3.data, cdims3) RB.fan == true ? 
(X4 = ReLU(Y3)) : (X4 = GaLU(Y3)) @@ -127,7 +129,7 @@ end # Backward function backward(ΔX4::AbstractArray{Float32, N}, X1::AbstractArray{Float32, N}, - RB::ResidualBlock{S1,S2,P1,P2}; set_grad::Bool=true) where {S1,S2,P1,P2,N} + RB::ResidualBlock; set_grad::Bool=true) where {N} inds = [i!=(N-1) ? 1 : (:) for i=1:N] dims = collect(1:N-1); dims[end] +=1 @@ -135,8 +137,8 @@ function backward(ΔX4::AbstractArray{Float32, N}, X1::AbstractArray{Float32, N} Y1, Y2, Y3, X2, X3 = forward(X1, RB; save=true) # Cdims - cdims2 = DenseConvDims(X2, RB.W2.data; stride=S2, padding=P2) - cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1) + cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2]) + cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1]) # Backpropagate residual ΔX4 and compute gradients RB.fan == true ? (ΔY3 = ReLUgrad(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3)) @@ -148,7 +150,7 @@ function backward(ΔX4::AbstractArray{Float32, N}, X1::AbstractArray{Float32, N} ΔW2 = ∇conv_filter(X2, ΔY2, cdims2) Δb2 = sum(ΔY2, dims=dims)[inds...] - cdims1 = DenseConvDims(X1, RB.W1.data; stride=S1, padding=P1) + cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1]) ΔY1 = ReLUgrad(ΔX2, Y1) ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1) @@ -171,24 +173,24 @@ end ## Jacobian-related functions function jacobian(ΔX1::AbstractArray{Float32, N}, Δθ::Array{Parameter, 1}, - X1::AbstractArray{Float32, N}, RB::ResidualBlock{S1,S2,P1,P2}) where {S1,S2,P1,P2,N} + X1::AbstractArray{Float32, N}, RB::ResidualBlock) where {N} inds = [i!=(N-1) ? 1 : (:) for i=1:N] # Cdims - cdims1 = DenseConvDims(X1, RB.W1.data; stride=S1, padding=P1) + cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1]) Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...) ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...) 
X2 = ReLU(Y1) ΔX2 = ReLUgrad(ΔY1, Y1) - cdims2 = DenseConvDims(X2, RB.W2.data; stride=S2, padding=P2) + cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2]) Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...) ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...) X3 = ReLU(Y2) ΔX3 = ReLUgrad(ΔY2, Y2) - cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1) + cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1]) Y3 = ∇conv_data(X3, RB.W3.data, cdims3) ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3) if RB.fan == true