Skip to content

Commit

Permalink
Merge pull request #82 from yuehhua/abstract
Browse files Browse the repository at this point in the history
Add abstract types for models and kernels
  • Loading branch information
foldfelis authored Jul 29, 2022
2 parents 63f0c58 + 77c0f6e commit 54602e6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/DeepONet/DeepONet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
```
"""
struct DeepONet{T1, T2}
struct DeepONet{T1, T2} <: AbstractOperatorModel
branch_net::T1
trunk_net::T2
end
Expand Down
55 changes: 41 additions & 14 deletions src/FNO/FNO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@ export
FourierNeuralOperator,
MarkovNeuralOperator

struct FourierNeuralOperator{L, K, P} <: AbstractOperatorModel
lifting_net::L
integral_kernel_net::K
project_net::P
end

Flux.@functor FourierNeuralOperator

"""
FourierNeuralOperator(;
ch = (2, 64, 64, 64, 64, 64, 128, 1),
Expand Down Expand Up @@ -85,16 +93,31 @@ function FourierNeuralOperator(;
modes = (16,),
σ = gelu)
Transform = FourierTransform
lifting = Dense(ch[1], ch[2])
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
OperatorKernel(ch[5] => ch[6], modes, Transform))
project = Chain(Dense(ch[6], ch[7], σ),
Dense(ch[7], ch[8]))

return FourierNeuralOperator(lifting, mapping, project)
end

function (fno::FourierNeuralOperator)(𝐱::AbstractArray)
lifted = fno.lifting_net(𝐱)
mapped = fno.integral_kernel_net(lifted)
𝐲 = fno.project_net(mapped)

return Chain(Dense(ch[1], ch[2]),
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
OperatorKernel(ch[5] => ch[6], modes, Transform),
Dense(ch[6], ch[7], σ),
Dense(ch[7], ch[8]))
return 𝐲
end

struct MarkovNeuralOperator{F} <: AbstractOperatorModel
fno::F
end

Flux.@functor MarkovNeuralOperator

"""
MarkovNeuralOperator(;
ch = (1, 64, 64, 64, 64, 64, 1),
Expand Down Expand Up @@ -176,11 +199,15 @@ function MarkovNeuralOperator(;
modes = (24, 24),
σ = gelu)
Transform = FourierTransform

return Chain(Dense(ch[1], ch[2]),
OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
OperatorKernel(ch[5] => ch[6], modes, Transform, σ),
Dense(ch[6], ch[7]))
lifting = Dense(ch[1], ch[2])
mapping = Chain(OperatorKernel(ch[2] => ch[3], modes, Transform, σ),
OperatorKernel(ch[3] => ch[4], modes, Transform, σ),
OperatorKernel(ch[4] => ch[5], modes, Transform, σ),
OperatorKernel(ch[5] => ch[6], modes, Transform, σ))
project = Dense(ch[6], ch[7])
fno = FourierNeuralOperator(lifting, mapping, project)

return MarkovNeuralOperator(fno)
end

(mno::MarkovNeuralOperator)(𝐱::AbstractArray) = mno.fno(𝐱)
2 changes: 1 addition & 1 deletion src/NOMAD/NOMAD.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export NOMAD

struct NOMAD{T1, T2}
struct NOMAD{T1, T2} <: AbstractOperatorModel
approximator_net::T1
decoder_net::T2
end
Expand Down
2 changes: 2 additions & 0 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using ChainRulesCore
using GeometricFlux
using Statistics

include("abstracttypes.jl")

# kernels
include("Transform/Transform.jl")
include("operator_kernel.jl")
Expand Down
6 changes: 6 additions & 0 deletions src/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export
AbstractOperatorModel,
AbstractOperatorKernel

abstract type AbstractOperatorModel end
abstract type AbstractOperatorKernel end
2 changes: 1 addition & 1 deletion src/operator_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end
# operator #
############

struct OperatorKernel{L, C, F}
struct OperatorKernel{L, C, F} <: AbstractOperatorKernel
linear::L
conv::C
σ::F
Expand Down

0 comments on commit 54602e6

Please sign in to comment.