Skip to content

Commit

Permalink
Automatically chose eltype of OperatorConv weigths
Browse files Browse the repository at this point in the history
  • Loading branch information
tknopp committed Aug 20, 2022
1 parent 60b7726 commit 1a5b271
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ struct ChebyshevTransform{N, S} <: AbstractTransform
end

Base.ndims(::ChebyshevTransform{N}) where {N} = N
Base.eltype(::Type{ChebyshevTransform}) = Float32

function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
return FFTW.r2r(𝐱, FFTW.REDFT10, 1:N) # [size(x)..., in_chs, batch]
Expand Down
1 change: 1 addition & 0 deletions src/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ struct FourierTransform{N, S} <: AbstractTransform
end

Base.ndims(::FourierTransform{N}) where {N} = N
Base.eltype(::Type{FourierTransform}) = ComplexF32

function transform(ft::FourierTransform, 𝐱::AbstractArray)
return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
Expand Down
8 changes: 5 additions & 3 deletions src/operator_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

"""
OperatorConv(ch, modes, transform;
init=c_glorot_uniform, permuted=false, T=ComplexF32)
init=glorot_uniform, permuted=false, T=ComplexF32)
## Arguments
Expand Down Expand Up @@ -49,9 +49,9 @@ OperatorConv(2 => 5, (16,), FourierTransform, permuted=true)
function OperatorConv(ch::Pair{S, S},
modes::NTuple{N, S},
Transform::Type{<:AbstractTransform};
init = c_glorot_uniform,
init = (dims...) -> Flux.glorot_uniform(eltype(Transform), dims...),
permuted = false,
T::DataType = ComplexF32) where {S <: Integer, N}
T::DataType = eltype(Transform)) where {S <: Integer, N}
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(prod(modes), in_chs, out_chs)
Expand Down Expand Up @@ -185,6 +185,8 @@ end
#########

c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
Flux.glorot_uniform(::Type{<:Real}, dims...) = Flux.glorot_uniform(dims...)
Flux.glorot_uniform(::Type{<:Complex}, dims...) = c_glorot_uniform(dims...)

# [prod(modes), out_chs, batch] <- [prod(modes), in_chs, batch] * [out_chs, in_chs, prod(modes)]
einsum(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[m, i, o]
Expand Down

0 comments on commit 1a5b271

Please sign in to comment.