This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

refactor: structs for NeuralOperators (#23)
* structs for Neural Operators

* bug fix

* bug fix

* fixing fno struct

* dispatch for TrainState

* removing structs for compact layers

* DeepONet : Compact => Container layers

* deeponet test bug fix

* deeponet fixes

* FNO : Compact => Container layers

* OperatorKernel : Compact => Container layers

* dispatch fixes from review

* fix: missing specializations

* fix: access AbstractExplicitContainerLayer from LuxCore

* fix: stop reexporting Lux

* test: AMDGPU tests are no longer broken

---------

Co-authored-by: Avik Pal <[email protected]>
ayushinav and avik-pal authored Aug 20, 2024
1 parent 2f2de89 commit e248a56
Showing 10 changed files with 129 additions and 104 deletions.
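
The thrust of these commits is replacing `@compact`-based layers with explicit structs that subtype `AbstractExplicitContainerLayer` and carry their sub-layers as fields. A minimal sketch of that pattern, using a hypothetical two-branch layer (the name `TwoBranch` and its fields are illustrative only, not part of this commit):

```julia
using Lux, LuxCore, Random
using ConcreteStructs: @concrete

# Hypothetical container layer, following the pattern adopted in this refactor:
# the field names listed in the type parameter become the keys of ps/st.
@concrete struct TwoBranch <: LuxCore.AbstractExplicitContainerLayer{(:net_a, :net_b)}
    net_a
    net_b
end

# Explicit forward pass: thread parameters and states through each sub-layer.
function (m::TwoBranch)(x, ps, st::NamedTuple)
    ya, st_a = m.net_a(x, ps.net_a, st.net_a)
    yb, st_b = m.net_b(x, ps.net_b, st.net_b)
    return ya .+ yb, (net_a=st_a, net_b=st_b)
end

layer = TwoBranch(Dense(4 => 8), Dense(4 => 8))
ps, st = Lux.setup(Xoshiro(0), layer)
y, st_new = layer(rand(Float32, 4, 3), ps, st)   # y is 8 x 3
```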
8 changes: 3 additions & 5 deletions Project.toml
@@ -13,19 +13,17 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[compat]
ArgCheck = "2.3.0"
ChainRulesCore = "1.24.0"
ConcreteStructs = "0.2.3"
FFTW = "1.8.0"
Lux = "0.5.62"
LuxCore = "0.1.21"
LuxLib = "0.3.40"
Lux = "0.5.64"
LuxCore = "0.1.24"
LuxLib = "0.3.42"
NNlib = "0.9.21"
Random = "1.10"
Reexport = "1.2.2"
WeightInitializers = "1"
julia = "1.10"
5 changes: 1 addition & 4 deletions src/NeuralOperators.jl
@@ -5,16 +5,13 @@ using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FFTW: FFTW, irfft, rfft
using Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxLib: batched_matmul
using NNlib: NNlib, batched_adjoint
using Random: Random, AbstractRNG
using Reexport: @reexport

const CRC = ChainRulesCore

@reexport using Lux

include("utils.jl")
include("transform.jl")

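With `@reexport using Lux` gone, `using NeuralOperators` alone no longer brings Lux names into scope. A one-line illustration of the downstream implication (assumed user code, not part of the commit):

```julia
# Chain, Dense, gelu, etc. must now come from an explicit `using Lux`.
using Lux, NeuralOperators
```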
105 changes: 56 additions & 49 deletions src/deeponet.jl
@@ -1,16 +1,16 @@
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)
DeepONet(branch, trunk, additional)
Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
`trunk` are same.
Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
nets output should have the same first dimension.
## Keyword arguments:
## Arguments
- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.
## Keyword Arguments
- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
for embeddings, defaults to `nothing`
@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example
```jldoctest
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
julia> deeponet = DeepONet(branch_net, trunk_net);
julia> ps, st = Lux.setup(Xoshiro(), deeponet);
@@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net; additional)
@concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
branch
trunk
additional
end

"""
DeepONet(branch, trunk)
DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())

Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
nets output should have the same first dimension.
## Arguments
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)
- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.
Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
`trunk` are same.
## Keyword Arguments
## Keyword arguments:
- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
for embeddings, defaults to `nothing`
@@ -78,11 +72,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example
```jldoctest
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
julia> deeponet = DeepONet(branch_net, trunk_net);
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
julia> ps, st = Lux.setup(Xoshiro(), deeponet);
@@ -94,15 +84,32 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb
b = branch(u) # p x u_size... x nb
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=NoOpLayer())

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net, additional)
end

function (deeponet::DeepONet)(x, ps, st::NamedTuple)
b, st_b = deeponet.branch(x[1], ps.branch, st.branch)
t, st_t = deeponet.trunk(x[2], ps.trunk, st.trunk)

@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."
@argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

@return __project(b, t, additional)
end
out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
return out, (branch=st_b, trunk=st_t, additional=st_a)
end
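
Taken together, the container-layer `DeepONet` is called like any other Lux layer. A short sketch based on the docstring above; the `u`/`y` shapes are assumptions, chosen to reproduce the `(10, 5)` output shown:

```julia
using Lux, NeuralOperators, Random

branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net  = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
deeponet   = DeepONet(branch_net, trunk_net)    # `additional` defaults to NoOpLayer()

ps, st = Lux.setup(Xoshiro(0), deeponet)

u = rand(Float32, 64, 5)     # branch input: in_features x batch (assumed)
y = rand(Float32, 1, 10, 5)  # trunk input:  coord_dim x N x batch (assumed)

out, _ = deeponet((u, y), ps, st)
size(out)                    # (10, 5), i.e. N x batch
```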
29 changes: 21 additions & 8 deletions src/fno.jl
@@ -27,7 +27,7 @@ kernels, and two `Dense` layers to project data back to the scalar field of inte
## Example
```jldoctest
julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,));
julia> fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,));
julia> ps, st = Lux.setup(Xoshiro(), fno);
Expand All @@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st)))
(1, 1024, 5)
```
"""
function FourierNeuralOperator(
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
@concrete struct FourierNeuralOperator <:
AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
lifting
mapping
project
end

function FourierNeuralOperator(;
σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
@argcheck length(chs) ≥ 5

@@ -52,9 +59,15 @@ function FourierNeuralOperator(
project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
Chain(Dense(map₂, σ), Dense(map₃))

return Chain(; lifting,
mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...),
project,
name="FourierNeuralOperator")
mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...)

return FourierNeuralOperator(lifting, mapping, project)
end

function (fno::FourierNeuralOperator)(x::AbstractArray, ps, st::NamedTuple)
lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
project, st_project = fno.project(mapping, ps.project, st.project)
return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
end
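
The constructor now takes the activation as the keyword `σ` instead of a positional argument. A brief sketch mirroring the updated docstring; the input shape is assumed from the `(1, 1024, 5)` output it shows:

```julia
using Lux, NeuralOperators, Random

fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(Xoshiro(0), fno)

u = rand(Float32, 2, 1024, 5)   # channels x grid x batch (assumed layout)
out, _ = fno(u, ps, st)
size(out)                       # (1, 1024, 5), as in the docstring above
```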
28 changes: 20 additions & 8 deletions src/layers.jl
@@ -116,18 +116,30 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(
```
"""
@concrete struct OperatorKernel <: AbstractExplicitContainerLayer{(:lin, :conv)}
lin
conv
activation <: Function
end

OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity)

function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
act = allow_fast_activation ? NNlib.fast_act(act) : act
l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)

return @compact(; l₁, l₂, activation=act, dispatch=:OperatorKernel) do x::AbstractArray
l₁x = l₁(x)
l₂x = l₂(x)
@return @. activation(l₁x + l₂x)
end
lin = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
conv = OperatorConv(ch, modes, transform; permuted, kwargs...)

return OperatorKernel(lin, conv, act)
end

function (op::OperatorKernel)(x::AbstractArray, ps, st::NamedTuple)
x_conv, st_conv = op.conv(x, ps.conv, st.conv)
x_lin, st_lin = op.lin(x, ps.lin, st.lin)

out = fast_activation!!(op.activation, x_conv .+ x_lin)
return out, (lin=st_lin, conv=st_conv)
end

"""
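`OperatorKernel` gets the same treatment: `lin` and `conv` become struct fields and their outputs are summed through `fast_activation!!`. A hedged usage sketch (grid size, batch size, and the non-permuted channels-first layout are assumptions):

```julia
using Lux, NeuralOperators, Random

op = OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(false))
ps, st = Lux.setup(Xoshiro(0), op)

x = rand(Float32, 2, 64, 5)   # channels x grid x batch (assumed)
out, _ = op(x, ps, st)
size(out)                     # expected (5, 64, 5)
```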
38 changes: 19 additions & 19 deletions src/utils.jl
@@ -1,24 +1,24 @@
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb
end

@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
# b : p x u x nb
# t : p x N x nb
if size(b, 2) == 1 || size(t, 2) == 1
return sum(b .* t; dims=1) # 1 x N x nb
return sum(b .* t; dims=1), () # 1 x N x nb
else
return batched_matmul(batched_adjoint(b), t) # u x N x b
return batched_matmul(batched_adjoint(b), t), () # u x N x b
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2, N}
@inline function __project(
b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2, N}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -29,34 +29,34 @@ end
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb
return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb
end

@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::T, params) where {T1, T2, T}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return additional(b_ .* t) # p x N x nb => out_dims x N x nb
return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
end

@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::T, params) where {T1, T2, T}
# b : p x u x nb
# t : p x N x nb

if size(b, 2) == 1 || size(t, 2) == 1
return additional(b .* t) # p x N x nb => out_dims x N x nb
return additional(b .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
else
b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb

return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u x N x nb => out_size x N x nb
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::T) where {T1, T2, N, T}
additional::T, params) where {T1, T2, N, T}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -67,5 +67,5 @@ end
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
end
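
The `__project` helpers now return the `additional` network's state alongside the output, but the core contraction for 3-D branch/trunk outputs is still Σᵢ bᵢⱼ tᵢₖ via a batched matrix product. A standalone sketch of just that contraction (illustrative sizes; not the package's internal call):

```julia
using NNlib: batched_mul, batched_adjoint

p, u, N, nb = 16, 4, 32, 5
b = rand(Float32, p, u, nb)               # branch output: p x u x nb
t = rand(Float32, p, N, nb)               # trunk output:  p x N x nb

out = batched_mul(batched_adjoint(b), t)  # u x N x nb: Σᵢ b[i, j, k] * t[i, l, k]
@assert size(out) == (u, N, nb)
```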
6 changes: 3 additions & 3 deletions test/Project.toml
@@ -25,9 +25,9 @@ Documenter = "1.5.0"
ExplicitImports = "1.9.0"
Hwloc = "3.2.0"
InteractiveUtils = "<0.0.1, 1"
Lux = "0.5.62"
LuxCore = "0.1.21"
LuxLib = "0.3.40"
Lux = "0.5.64"
LuxCore = "0.1.24"
LuxLib = "0.3.42"
LuxTestUtils = "1.1.2"
MLDataDevices = "1.0.0"
Optimisers = "0.3.3"
3 changes: 1 addition & 2 deletions test/fno_tests.jl
@@ -22,11 +22,10 @@
@test size(first(fno(x, ps, st))) == setup.y_size

data = [(x, y)]
broken = mode == "AMDGPU"
@test begin
l2, l1 = train!(fno, ps, st, data; epochs=10)
l2 < l1
end broken=broken
end

__f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,