DenseConvDims not always type stable #274

Closed
ghost opened this issue Jan 27, 2021 · 8 comments
ghost commented Jan 27, 2021

The following code example shows that the type of DenseConvDims cannot be properly inferred when its input types are all NTuples. When changing one of the inputs to an array it can for some reason be correctly inferred.

This seems to be the source of FluxML/Flux.jl#1350. Maybe linked to #125?

using NNlib
using InteractiveUtils

function main()
    array = [1,1,1,1]::Array{Int64, 1}
    ntuple = (1,1,1,1)::NTuple{4, Int64}

    # type stable
    @code_warntype NNlib.DenseConvDims(array, ntuple)
    @code_warntype NNlib.DenseConvDims(ntuple, array)
    
    # type unstable
    @code_warntype NNlib.DenseConvDims(ntuple, ntuple)
end
main()
Variables
  #self#::Type{DenseConvDims}
  x_size::NTuple{4,Int64}
  w_size::NTuple{4,Int64}

Body::DenseConvDims{2,_A,_B,_C,_D,_E,_F,_G} where _G where _F where _E where _D where _C where _B where _A
1 ─ %1 = NNlib.:(var"#DenseConvDims#6")(1, 0, 1, false, #self#, x_size, w_size)::DenseConvDims{2,_A,_B,_C,_D,_E,_F,_G} where _G where _F where _E where _D where _C where _B where _A
└──      return %1
@CarloLucibello (Member)

glad you found this, should be easy to fix

@CarloLucibello (Member)

Actually I don't think the constructor can be made type stable, since the type parameter is a runtime input value:

julia> struct A{N}; x::Int; end

julia> A(i) = A{i}(10)
A

julia> @code_warntype A(1)
Variables
  #self#::Type{A}
  i::Int64

Body::A{_A} where _A
1 ─ %1 = Core.apply_type(Main.A, i)::Type{A{_A}} where _A
│   %2 = (%1)(10)::A{_A} where _A
└──      return %2
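
For comparison, and purely as a sketch of the general workaround (not something the current NNlib constructors do), the same toy constructor becomes inferable if the value is lifted into the type domain with Val, because the parameter is then known to the compiler:

julia> A(::Val{i}) where {i} = A{i}(10)   # hypothetical Val-based constructor
A

julia> A(Val(1))                          # Val(1)::Val{1} carries the value in its type, so A{1} is inferred
A{1}(10)

The trade-off is that the caller then needs the value as a compile-time constant, which the DenseConvDims call sites generally don't have.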

@joostveenema why do you think this is related to FluxML/Flux.jl#1350?

@CarloLucibello (Member)

When changing one of the inputs to an array it can for some reason be correctly inferred.

They are not inferred; those calls just throw an error.

ghost commented Feb 3, 2021

You are right, the code example does not make sense. I thought I had simplified the problem. Let's try again.

In FluxML/Flux.jl#1350 the problem seems to be that the return type of NNlib.channels_in is Any in some cases. In that case the return type of the similar call becomes Union{CuArray, OffsetArray} if the OffsetArrays package is loaded (or any other package that has it as a dependency). I am not 100% sure why, but the crash there only seems to occur when the input type of tanh is inferred as a Union.

Compare the following two snippets. In the first one the type of DenseConvDims is more concrete and the return type of channels_in is inferred (in both cases the type parameters are runtime input):

using NNlib
using InteractiveUtils

function example(W::AbstractArray)
  cdims = DenseConvDims(
    (1, 1, 1, 1),
    (size(W)[1:end-1]..., 1),
  )
  
  NNlib.channels_in(cdims)
end

x = rand(Float32, 1, 1, 1, 1)

@code_warntype example(x)
Variables
  #self#::Core.Compiler.Const(example, false)
  W::Array{Float32,4}
  cdims::DenseConvDims{2,_A,1,1,(1, 1),_B,(1, 1),false} where _B where _A

Body::Int64
1 ─ %1  = Core.tuple(1, 1, 1, 1)::Core.Compiler.Const((1, 1, 1, 1), false)
│   %2  = Main.size(W)::NTuple{4,Int64}
│   %3  = Base.lastindex(%2)::Core.Compiler.Const(4, false)
│   %4  = (%3 - 1)::Core.Compiler.Const(3, false)
│   %5  = (1:%4)::Core.Compiler.Const(1:3, false)
│   %6  = Base.getindex(%2, %5)::Tuple{Int64,Int64,Int64}
│   %7  = Core.tuple(1)::Core.Compiler.Const((1,), false)
│   %8  = Core._apply_iterate(Base.iterate, Core.tuple, %6, %7)::Core.Compiler.PartialStruct(NTuple{4,Int64}, Any[Int64, Int64, Int64, Core.Compiler.Const(1, false)])
│         (cdims = Main.DenseConvDims(%1, %8))
│   %10 = NNlib.channels_in::Core.Compiler.Const(NNlib.channels_in, false)
│   %11 = (%10)(cdims)::Core.Compiler.Const(1, false)     <----- return type and value known
└──       return %11

vs

using NNlib
using InteractiveUtils

function example(W::AbstractArray)
  cdims = DenseConvDims(
    (1, 1, 1, 1),
    size(W),                              # <---- This changed
  )
  
  NNlib.channels_in(cdims)
end

x = rand(Float32, 1, 1, 1, 1)

@code_warntype example(x)
Variables
  #self#::Core.Compiler.Const(example, false)
  W::Array{Float32,4}
  cdims::DenseConvDims{2,_A,_B,_C,_D,_E,_F,_G} where _G where _F where _E where _D where _C where _B where _A

Body::Any
1 ─ %1 = Core.tuple(1, 1, 1, 1)::Core.Compiler.Const((1, 1, 1, 1), false)
│   %2 = Main.size(W)::NTuple{4,Int64}
│        (cdims = Main.DenseConvDims(%1, %2))
│   %4 = NNlib.channels_in::Core.Compiler.Const(NNlib.channels_in, false)
│   %5 = (%4)(cdims)::Any   <-------- return type unknown
└──      return %5

I am not that familiar with the Julia internals, so I might be wrong here.

@DhairyaLGandhi (Member)

I believe it's because of the splat.

@CarloLucibello (Member)

With the output type forced in the channels_in definition

channels_in(c::DenseConvDims{N,K,C_in,C_out}) where {N,K,C_in,C_out} = C_in::Int

inference is still able to recover both the return type and the value in the first example, and in the second example the return type is now inferred as well:

julia> @code_warntype example(x)
Variables
  #self#::Core.Compiler.Const(example, false)
  W::Array{Float32,4}
  cdims::DenseConvDims{2,_A,_B,_C,_D,_E,_F,_G} where _G where _F where _E where _D where _C where _B where _A

Body::Int64
1 ─ %1 = Core.tuple(1, 1, 1, 1)::Core.Compiler.Const((1, 1, 1, 1), false)
│   %2 = Main.size(W)::NTuple{4,Int64}
│        (cdims = Main.DenseConvDims(%1, %2))
│   %4 = NNlib.channels_in::Core.Compiler.Const(NNlib.channels_in, false)
│   %5 = (%4)(cdims)::Int64
└──      return %5
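
The reason the assertion helps is not specific to DenseConvDims; it is the general pattern that even when a type parameter is not statically known, a return-type assertion tells inference what the accessor must yield (or else throw). A minimal toy example, not NNlib code:

struct B{N} end                          # toy type whose parameter holds a runtime value

npar(b::B{N}) where {N} = N              # inferred as Any when the concrete B{N} is unknown
npar_typed(b::B{N}) where {N} = N::Int   # the assertion pins the return type to Int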

Is this enough to fix FluxML/Flux.jl#1350?

Maybe we should revisit the ConvDims design; it feels like we are abusing parameterization and putting a lot of stress on the inference engine.
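
One possible direction, purely as a sketch and not a concrete proposal, would be to keep only the spatial rank as a type parameter and store everything else in plain fields, so that accessors such as channels_in are trivially inferable. The names below are illustrative, not NNlib's API:

struct FieldConvDims{N}
    channels_in::Int
    channels_out::Int
    kernel_size::NTuple{N,Int}
    stride::NTuple{N,Int}
    padding::NTuple{N,Int}
    dilation::NTuple{N,Int}
    flipkernel::Bool
end

channels_in(c::FieldConvDims) = c.channels_in   # always ::Int, no type parameters involved

The cost would be that kernels could no longer specialize on these values at compile time, which is presumably why they were made type parameters in the first place.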

ghost commented Feb 4, 2021

Yes! This solves the crashes for me. I put the change I tested with in PR #275.

@CarloLucibello (Member)

Closed in #275.
