From 1a5b271ee19d969d7c9fc9c486f0cec2105e33da Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Sat, 20 Aug 2022 22:19:35 +0200 Subject: [PATCH] Automatically chose eltype of OperatorConv weigths --- src/Transform/chebyshev_transform.jl | 1 + src/Transform/fourier_transform.jl | 1 + src/operator_kernel.jl | 8 +++++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index 4ec783d7..a6c0191d 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -5,6 +5,7 @@ struct ChebyshevTransform{N, S} <: AbstractTransform end Base.ndims(::ChebyshevTransform{N}) where {N} = N +Base.eltype(::Type{ChebyshevTransform}) = Float32 function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N} return FFTW.r2r(𝐱, FFTW.REDFT10, 1:N) # [size(x)..., in_chs, batch] diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index ed040f96..09e581b3 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -5,6 +5,7 @@ struct FourierTransform{N, S} <: AbstractTransform end Base.ndims(::FourierTransform{N}) where {N} = N +Base.eltype(::Type{FourierTransform}) = ComplexF32 function transform(ft::FourierTransform, 𝐱::AbstractArray) return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch] diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index fa0a2eea..d131ad34 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -19,7 +19,7 @@ end """ OperatorConv(ch, modes, transform; - init=c_glorot_uniform, permuted=false, T=ComplexF32) + init=glorot_uniform, permuted=false, T=ComplexF32) ## Arguments @@ -49,9 +49,9 @@ OperatorConv(2 => 5, (16,), FourierTransform, permuted=true) function OperatorConv(ch::Pair{S, S}, modes::NTuple{N, S}, Transform::Type{<:AbstractTransform}; - init = c_glorot_uniform, + init = (dims...) -> Flux.glorot_uniform(eltype(Transform), dims...), permuted = false, - T::DataType = ComplexF32) where {S <: Integer, N} + T::DataType = eltype(Transform)) where {S <: Integer, N} in_chs, out_chs = ch scale = one(T) / (in_chs * out_chs) weights = scale * init(prod(modes), in_chs, out_chs) @@ -185,6 +185,8 @@ end ######### c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im +Flux.glorot_uniform(::Type{<:Real}, dims...) = Flux.glorot_uniform(dims...) +Flux.glorot_uniform(::Type{<:Complex}, dims...) = c_glorot_uniform(dims...) # [prod(modes), out_chs, batch] <- [prod(modes), in_chs, batch] * [out_chs, in_chs, prod(modes)] einsum(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[m, i, o]