
Tied weights using Flux layers #1592

Open
dfenn opened this issue May 7, 2021 · 9 comments

@dfenn

dfenn commented May 7, 2021

I'm trying to build an autoencoder that uses both conv and dense layers, and I'd like to tie the weights. #488 demonstrates how to do this for dense layers by not using the Flux Dense type and instead using the encoder's weights directly.

Is there a way to accomplish something similar while still using Flux-defined layer types, such as Conv? I've tried manually setting the decoder parameters in the loss function; something like this:

using Flux

mutable struct AE_tied
    encoder
    decoder

    weights_encoder
    weights_decoder
end

AE_tied(encoder, decoder) = AE_tied(encoder, decoder, params(encoder), params(decoder))

function (a::AE_tied)(x)
    x = a.encoder(x)
    a.weights_decoder[1] .= a.weights_encoder[1]
    a.decoder(x)
end

encoder = Conv((3,3), 1=>2, relu, pad=SamePad())
decoder = ConvTranspose((3,3), 2=>1, relu, pad=SamePad())

model = AE_tied(encoder, decoder)
model = cpu(model)

ps = Flux.params(model.encoder)
opt = ADAM(0.1)

function loss(x) 
    y = model(x)
    sum((y .- x) .^2) / length(x)
end

train_data = cpu(rand(5, 5, 1, 2))

for epoch in 1:1
    local trainLoss
    gs = Flux.gradient(ps) do
        trainLoss = loss(train_data)
        return trainLoss
    end
    Flux.Optimise.update!(opt, ps, gs)
    @show trainLoss
end

Running this gives ERROR: LoadError: Mutating arrays is not supported. It's the line a.weights_decoder[1] .= a.weights_encoder[1] that's the issue.

Am I going about this the wrong way, or is what I'm trying to do not supported? Thanks in advance for any help.

@atiyo
Contributor

atiyo commented May 9, 2021

Indeed, mutating arrays isn't supported by Zygote, which Flux uses to calculate the gradients. It is supported in some other Julia AD packages, which you might be able to use instead.

However, I don't believe the above snippet actually ties the weights properly. For example, a tied weight should receive a single gradient accumulated from both of its uses, and that won't happen if you tie the weights by manually copying one onto the other.

With this in mind, my preferred solution would be to initialise the weights of the decoder to be a @view on the weights of the encoder.

I haven't actually checked to see whether this plays nicely with Flux, but maybe it's something to try.
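
An untested sketch of what I mean: for the Conv((3,3), 1=>2) / ConvTranspose((3,3), 2=>1) pair above, the filter arrays happen to have the same size (3, 3, 1, 2), so the simplest version is to construct both layers from one shared array (a whole-array @view of the encoder's weight would behave the same way); only the biases stay separate.

using Flux

# One filter array for both layers: Conv((3,3), 1=>2) and ConvTranspose((3,3), 2=>1)
# both expect a (3, 3, 1, 2) weight.
W = Flux.glorot_uniform(3, 3, 1, 2)

encoder = Conv(W, zeros(Float32, 2), relu; pad=SamePad())           # 1 channel -> 2 channels
decoder = ConvTranspose(W, zeros(Float32, 1), relu; pad=SamePad())  # 2 channels -> 1 channel

# encoder.weight === decoder.weight, so there is a single array to train, and
# Zygote accumulates the gradient from both uses into that one array.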

@dfenn
Author

dfenn commented May 11, 2021

Thanks for your response. I was able to get it working with @views for the convolutional layers. However, the same approach isn't working for dense layers, where the weight matrix must be transposed:

encoder = Dense(5, 2)
@views decoder = Dense(transpose(encoder.weight), rand(5))

This gives the error

ERROR: LoadError: TypeError: in typeassert, expected Tuple{Transpose{Float32, Matrix{Float32}}, Transpose{Float32, Matrix{Float32}}, Vector{Float64}}, got a value of type Tuple{Matrix{Float32}, Matrix{Float32}, Vector{Float64}}
Stacktrace:
 [1] apply!(o::ADAM, x::Transpose{Float32, Matrix{Float32}}, Δ::Matrix{Float64})
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/optimisers.jl:175
 [2] update!(opt::ADAM, x::Transpose{Float32, Matrix{Float32}}, x̄::Matrix{Float64})
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:23
 [3] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
   @ Flux.Optimise ~/.julia/packages/Flux/6BByF/src/optimise/train.jl:29

It looks like Flux is inferring the type as Transpose and then complaining when it receives a Matrix. I've tried using PermutedDimsArray instead, with similar results.

It's not clear to me how to address this. Any ideas?

@darsnack
Member

We should probably change that line in the Adam code to use Adapt.jl to get the correct type instead of hard-typing the return of get!.
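
As far as I can tell, the mismatch comes from the state buffers being created as plain arrays while the typeassert expects typeof(x): zero (via similar) drops the Transpose wrapper, e.g.

using LinearAlgebra

A = rand(Float32, 2, 5)
W = transpose(A)
typeof(W)        # Transpose{Float32, Matrix{Float32}}
typeof(zero(W))  # Matrix{Float32} -- the wrapper is not preserved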

@DhairyaLGandhi
Member

DhairyaLGandhi commented May 11, 2021

Probably better to incorporate directly in optimisers.jl

As long as we pass in the correct references we should be good. I don't think it needs to be addressed in the optimisers otherwise.

@darsnack
Member

I don't think we need the fix in Optimisers.jl because the state is initialized separately (and correctly). This appears to only be a bug for IdDict optimizers.

Agreed that we only need the references to be correct.

@CarloLucibello
Member

CarloLucibello commented Jun 10, 2021

possibly related to FluxML/Zygote.jl#991 and #1613

We should probably change that line in the Adam code to use Adapt.jl to get the correct type instead of hard-typing the return of get!.

Even if we use something like #1613 to adapt the types, it still wouldn't be entirely correct, because we would be taking two steps of ADAM with separate gradients instead of a single step with the accumulated gradient.

@darsnack
Member

darsnack commented Jun 23, 2021

taking 2 steps of adam with separate gradients instead of a single step with the accumulated one

Yeah, with ADAM this will certainly be wrong. Referencing FluxML/Zygote.jl#991 (comment), it's not the two steps per se that are wrong; it's the momentum terms that will be incorrect, which makes two steps not equivalent to a single accumulated one. For simpler optimizers like Descent, this will be correct (assuming the gradients themselves are correct, which they are for explicit params).
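
A quick sketch of that point with the implicit-style optimisers discussed here: Descent's update is linear in the gradient, so two steps collapse into one, while ADAM's moment estimates make the two paths diverge.

using Flux
using Flux.Optimise: update!

g1, g2 = Float32[1, 2], Float32[3, 4]   # gradients from two uses of a tied weight

# Descent: two updates with g1 then g2 match one update with g1 + g2.
xa, xb = zeros(Float32, 2), zeros(Float32, 2)
update!(Descent(0.1), xa, copy(g1)); update!(Descent(0.1), xa, copy(g2))
update!(Descent(0.1), xb, g1 .+ g2)
xa ≈ xb   # true

# ADAM: the moment estimates depend on how the gradient arrives, so the results differ.
adam_a, adam_b = ADAM(0.1), ADAM(0.1)
ya, yb = zeros(Float32, 2), zeros(Float32, 2)
update!(adam_a, ya, copy(g1)); update!(adam_a, ya, copy(g2))
update!(adam_b, yb, g1 .+ g2)
ya ≈ yb   # false in general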

@mleprovost

Hello,

I wanted to follow up on this issue. Is it resolved in the latest version of Flux.jl?

@mcabbott
Member

On latest Flux, using new-style training with setup, something like dec = Dense(transpose(encoder.weight)) should just work. It will see through the transpose and notice that the same array appears twice.

(With old-style IdDict optimisers, I'm not sure.)
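
Roughly this shape, as an untested sketch (the loss and sizes are just for illustration):

using Flux

enc = Dense(5 => 2, relu)
dec = Dense(transpose(enc.weight))        # lazy transpose of the same underlying array
model = Chain(enc, dec)

opt_state = Flux.setup(Adam(1e-3), model) # setup sees the shared array and ties its state

x = rand(Float32, 5, 16)
loss, grads = Flux.withgradient(m -> sum(abs2, m(x) .- x) / length(x), model)
Flux.update!(opt_state, model, grads[1])

enc.weight === parent(dec.weight)         # true: still one shared array after the update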
