Skip to content

Commit

Permalink
Merge pull request #23 from slimgroup/functorfix
Browse files Browse the repository at this point in the history
Fix hyperbolic and residual block functors
  • Loading branch information
mloubout authored Mar 3, 2021
2 parents aac8985 + c159a96 commit e22195d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 38 deletions.
49 changes: 26 additions & 23 deletions src/layers/invertible_layer_hyperbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ Create an invertible hyperbolic coupling layer.
See also: [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct HyperbolicLayer{S, P, A} <: NeuralNetLayer
struct HyperbolicLayer <: NeuralNetLayer
W::Parameter
b::Parameter
α::Float32
stride::Int
pad::Int
action
end

@Flux.functor HyperbolicLayer
Expand All @@ -75,7 +78,7 @@ function HyperbolicLayer(n_in::Int64, kernel::Int64, stride::Int64,
W = Parameter(glorot_uniform(k..., n_out, n_hidden))
b = Parameter(zeros(Float32, n_hidden))

return HyperbolicLayer{stride, pad, action}(W, b, α)
return HyperbolicLayer(W, b, α,stride, pad, action)
end

HyperbolicLayer3D(args...; kw...) = HyperbolicLayer(args...; kw..., ndims=3)
Expand All @@ -91,7 +94,7 @@ function HyperbolicLayer(W::AbstractArray{Float32, N}, b::AbstractArray{Float32,
W = Parameter(W)
b = Parameter(b)

return HyperbolicLayer{stride, pad, action}(W, b, α)
return HyperbolicLayer(W, b, α, stride, pad, action)
end

HyperbolicLayer3D(W::AbstractArray{Float32, N}, b, stride, pad;
Expand All @@ -101,24 +104,24 @@ HyperbolicLayer3D(W::AbstractArray{Float32, N}, b, stride, pad;
#################################################

# Forward pass
function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p, a}
function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer)

# Change dimensions
if a == 0
if HL.action == 0
X_prev = identity(X_prev_in)
X_curr = identity(X_curr_in)
elseif a == 1
elseif HL.action == 1
X_prev = wavelet_unsqueeze(X_prev_in)
X_curr = wavelet_unsqueeze(X_curr_in)
elseif a == -1
elseif HL.action == -1
X_prev = wavelet_squeeze(X_prev_in)
X_curr = wavelet_squeeze(X_curr_in)
else
throw("Specified operation not defined.")
end

# Symmetric convolution w/ relu activation
cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p)
cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad)
if length(size(X_curr)) == 4
X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1)
else
Expand All @@ -134,8 +137,8 @@ function forward(X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p
end

# Inverse pass
function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where {s, p, a}
cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p)
function inverse(X_curr, X_new, HL::HyperbolicLayer; save=false)
cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad)
# Symmetric convolution w/ relu activation
if length(size(X_curr)) == 4
X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1)
Expand All @@ -149,13 +152,13 @@ function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where
X_prev = 2*X_curr - X_new + HL.α*X_convT

# Change dimensions
if a == 0
if HL.action == 0
X_prev_in = identity(X_prev)
X_curr_in = identity(X_curr)
elseif a == -1
elseif HL.action == -1
X_prev_in = wavelet_unsqueeze(X_prev)
X_curr_in = wavelet_unsqueeze(X_curr)
elseif a == 1
elseif HL.action == 1
X_prev_in = wavelet_squeeze(X_prev)
X_curr_in = wavelet_squeeze(X_curr)
else
Expand All @@ -170,13 +173,13 @@ function inverse(X_curr, X_new, HL::HyperbolicLayer{s, p, a}; save=false) where
end

# Backward pass
function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer{s, p, a}; set_grad::Bool=true) where {s, p, a}
function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer; set_grad::Bool=true)

# Recompute forward states
X_prev_in, X_curr_in, X_conv, X_relu = inverse(X_curr, X_new, HL; save=true)

# Backpropagate data residual and compute gradients
cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p)
cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad)
ΔX_convT = copy(ΔX_new)
ΔX_relu = -HL.α*conv(ΔX_convT, HL.W.data, cdims)
ΔW = -HL.α*∇conv_filter(ΔX_convT, X_relu, cdims)
Expand All @@ -201,13 +204,13 @@ function backward(ΔX_curr, ΔX_new, X_curr, X_new, HL::HyperbolicLayer{s, p, a}
end

# Change dimensions
if a == 0
if HL.action == 0
ΔX_prev_in = identity(ΔX_prev)
ΔX_curr_in = identity(ΔX_curr)
elseif a == -1
elseif HL.action == -1
ΔX_prev_in = wavelet_unsqueeze(ΔX_prev)
ΔX_curr_in = wavelet_unsqueeze(ΔX_curr)
elseif a == 1
elseif HL.action == 1
ΔX_prev_in = wavelet_squeeze(ΔX_prev)
ΔX_curr_in = wavelet_squeeze(ΔX_curr)
else
Expand All @@ -221,20 +224,20 @@ end
## Jacobian utilities

# 2D
function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::HyperbolicLayer{s, p, a}) where {s, p, a}
function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::HyperbolicLayer)

# Change dimensions
if a == 0
if HL.action == 0
X_prev = identity(X_prev_in)
X_curr = identity(X_curr_in)
ΔX_prev = identity(ΔX_prev_in)
ΔX_curr = identity(ΔX_curr_in)
elseif a == 1
elseif HL.action == 1
X_prev = wavelet_unsqueeze(X_prev_in)
X_curr = wavelet_unsqueeze(X_curr_in)
ΔX_prev = wavelet_unsqueeze(ΔX_prev_in)
ΔX_curr = wavelet_unsqueeze(ΔX_curr_in)
elseif a == -1
elseif HL.action == -1
X_prev = wavelet_squeeze(X_prev_in)
X_curr = wavelet_squeeze(X_curr_in)
ΔX_prev = wavelet_squeeze(ΔX_prev_in)
Expand All @@ -243,7 +246,7 @@ function jacobian(ΔX_prev_in, ΔX_curr_in, Δθ, X_prev_in, X_curr_in, HL::Hype
throw("Specified operation not defined.")
end

cdims = DCDims(X_curr, HL.W.data; stride=s, padding=p)
cdims = DCDims(X_curr, HL.W.data; stride=HL.stride, padding=HL.pad)
# Symmetric convolution w/ relu activation
if length(size(X_curr)) == 4
X_conv = conv(X_curr, HL.W.data, cdims) .+ reshape(HL.b.data, 1, 1, :, 1)
Expand Down
32 changes: 17 additions & 15 deletions src/layers/layer_residual_block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ or
See also: [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct ResidualBlock{S1, S2, P1, P2} <: NeuralNetLayer
struct ResidualBlock <: NeuralNetLayer
W1::Parameter
W2::Parameter
W3::Parameter
b1::Parameter
b2::Parameter
fan::Bool
strides
pad
end

@Flux.functor ResidualBlock
Expand All @@ -84,7 +86,7 @@ function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=f
b1 = Parameter(zeros(Float32, n_hidden))
b2 = Parameter(zeros(Float32, n_hidden))

return ResidualBlock{s1, s2, p1, p2}(W1, W2, W3, b1, b2, fan)
return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
end

# Constructor for given weights
Expand All @@ -97,24 +99,24 @@ function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, nd
b1 = Parameter(b1)
b2 = Parameter(b2)

return ResidualBlock{s1, s2, p1, p2}(W1, W2, W3, b1, b2, fan)
return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
end

ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3)
#######################################################################################################################
# Functions

# Forward
function forward(X1::AbstractArray{Float32, N}, RB::ResidualBlock{S1,S2,P1,P2}; save=false) where {S1,S2,P1,P2,N}
function forward(X1::AbstractArray{Float32, N}, RB::ResidualBlock; save=false) where {N}
inds =[i!=(N-1) ? 1 : (:) for i=1:N]

Y1 = conv(X1, RB.W1.data; stride=S1, pad=P1) .+ reshape(RB.b1.data, inds...)
Y1 = conv(X1, RB.W1.data; stride=RB.strides[1], pad=RB.pad[1]) .+ reshape(RB.b1.data, inds...)
X2 = ReLU(Y1)

Y2 = X2 + conv(X2, RB.W2.data; stride=S2, pad=P2) .+ reshape(RB.b2.data, inds...)
Y2 = X2 + conv(X2, RB.W2.data; stride=RB.strides[2], pad=RB.pad[2]) .+ reshape(RB.b2.data, inds...)
X3 = ReLU(Y2)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1)
cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
RB.fan == true ? (X4 = ReLU(Y3)) : (X4 = GaLU(Y3))

Expand All @@ -127,16 +129,16 @@ end

# Backward
function backward(ΔX4::AbstractArray{Float32, N}, X1::AbstractArray{Float32, N},
RB::ResidualBlock{S1,S2,P1,P2}; set_grad::Bool=true) where {S1,S2,P1,P2,N}
RB::ResidualBlock; set_grad::Bool=true) where {N}
inds = [i!=(N-1) ? 1 : (:) for i=1:N]
dims = collect(1:N-1); dims[end] +=1

# Recompute forward states from input X
Y1, Y2, Y3, X2, X3 = forward(X1, RB; save=true)

# Cdims
cdims2 = DenseConvDims(X2, RB.W2.data; stride=S2, padding=P2)
cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1)
cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])

# Backpropagate residual ΔX4 and compute gradients
RB.fan == true ? (ΔY3 = ReLUgrad(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
Expand All @@ -148,7 +150,7 @@ function backward(ΔX4::AbstractArray{Float32, N}, X1::AbstractArray{Float32, N}
ΔW2 = ∇conv_filter(X2, ΔY2, cdims2)
Δb2 = sum(ΔY2, dims=dims)[inds...]

cdims1 = DenseConvDims(X1, RB.W1.data; stride=S1, padding=P1)
cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

ΔY1 = ReLUgrad(ΔX2, Y1)
ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
Expand All @@ -171,24 +173,24 @@ end

## Jacobian-related functions
function jacobian(ΔX1::AbstractArray{Float32, N}, Δθ::Array{Parameter, 1},
X1::AbstractArray{Float32, N}, RB::ResidualBlock{S1,S2,P1,P2}) where {S1,S2,P1,P2,N}
X1::AbstractArray{Float32, N}, RB::ResidualBlock) where {N}
inds = [i!=(N-1) ? 1 : (:) for i=1:N]
# Cdims
cdims1 = DenseConvDims(X1, RB.W1.data; stride=S1, padding=P1)
cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
X2 = ReLU(Y1)
ΔX2 = ReLUgrad(ΔY1, Y1)

cdims2 = DenseConvDims(X2, RB.W2.data; stride=S2, padding=P2)
cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
X3 = ReLU(Y2)
ΔX3 = ReLUgrad(ΔY2, Y2)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=S1, padding=P1)
cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
if RB.fan == true
Expand Down

0 comments on commit e22195d

Please sign in to comment.