Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Commit

Permalink
Merge pull request #7 from foldfelis/gpu
Browse files Browse the repository at this point in the history
support GPU
  • Loading branch information
foldfelis authored Aug 11, 2021
2 parents 6a390d9 + 1817c9d commit 609e6dd
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 23 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,25 @@ authors = ["JingYu Ning <[email protected]> and contributors"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3.3"
CUDAKernels = "0.3"
DataDeps = "0.7"
FFTW = "1.4"
Fetch = "0.1"
Flux = "0.12"
KernelAbstractions = "0.7"
MAT = "0.10"
Tullio = "0.3"
Zygote = "0.6"
Expand Down
2 changes: 1 addition & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[NeuralOperators]]
path = ".."
uuid = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
version = "0.1.0"

[[Parsers]]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
NeuralOperators = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
19 changes: 10 additions & 9 deletions example/burgers.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using NeuralOperators
using Flux
# using CUDA
using CUDA

# if has_cuda()
# @info "CUDA is on"
# device = gpu
# CUDA.allowscalar(false)
# else
if has_cuda()
@info "CUDA is on"
device = gpu
CUDA.allowscalar(false)
else
device = cpu
# end
end

m = FourierNeuralOperator() |> device
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
Expand All @@ -24,13 +24,14 @@ n_test = 40
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=20, shuffle=false)

function loss_test()
l = 0
l = 0f0
for (𝐱, 𝐲) in loader_test
𝐱, 𝐲 = device(𝐱), device(𝐲)
l += loss(𝐱, 𝐲)
end
@info "loss: $(l/length(loader_test))"
end

data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10)))
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 5)))
3 changes: 3 additions & 0 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ module NeuralOperators
using Flux
using FFTW
using Tullio
using CUDA
using CUDAKernels
using KernelAbstractions
using Zygote

function __init__()
Expand Down
28 changes: 16 additions & 12 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export
SpectralConv1d,
FourierOperator,
FNO
FourierOperator

struct SpectralConv1d{T, S}
weight::T
Expand All @@ -28,26 +27,31 @@ function SpectralConv1d(

return Chain(
x -> Zygote.hook(real, x),
SpectralConv1d(weights, in_chs, out_chs, modes, σ)
SpectralConv1d(weights, in_chs, out_chs, modes, σ),
)
end

Flux.@functor SpectralConv1d

t(𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
𝐱ᵀ = t(𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]

# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
𝐱_weighted = ein_mul(𝐱_selected, m.weight)

s = size(𝐱_weighted)
d = size(𝐱, 2) - m.modes
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
s = size(𝐱_weighted)[2:end]
d = size(𝐱ᵀ, 1) - m.modes
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, d, s...), dims=1)

𝐱_out = ifft(𝐱_padded, 2)
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
𝐱_outᵀ = t(𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]

return m.σ.(real(𝐱_out))
return m.σ.(real(𝐱_outᵀ))
end

function FourierOperator(
Expand Down

0 comments on commit 609e6dd

Please sign in to comment.