Skip to content

Commit

Permalink
Properly extend methods
Browse files Browse the repository at this point in the history
  • Loading branch information
sloede committed Nov 3, 2023
1 parent 5c663fc commit fb91bd9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/TrixiSmartShockFinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TrixiSmartShockFinder
using MuladdMacro: @muladd
using Trixi
using Trixi: AbstractIndicator, AbstractEquations, AbstractSemidiscretization, @threaded,
trixi_include
summary_box, trixi_include

include("indicators.jl")
include("indicators_1d.jl")
Expand Down
16 changes: 8 additions & 8 deletions src/indicators_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

# this method is used when the indicator is constructed as for shock-capturing volume integrals
# empty cache is default
function create_cache(::Type{<:IndicatorNeuralNetwork},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{<:IndicatorNeuralNetwork},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
return NamedTuple()
end

# cache for NeuralNetworkPerssonPeraire-type indicator
function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
alpha = Vector{real(basis)}()
alpha_tmp = similar(alpha)
A = Array{real(basis), ndims(equations)}
Expand All @@ -27,8 +27,8 @@ function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}
end

# cache for NeuralNetworkRayHesthaven-type indicator
function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
equations::AbstractEquations{1}, basis::LobattoLegendreBasis)
alpha = Vector{real(basis)}()
alpha_tmp = similar(alpha)
A = Array{real(basis), ndims(equations)}
Expand All @@ -41,8 +41,8 @@ function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
end

# this method is used when the indicator is constructed as for AMR
function create_cache(typ::Type{<:IndicatorNeuralNetwork},
mesh, equations::AbstractEquations{1}, dg::DGSEM, cache)
function Trixi.create_cache(typ::Type{<:IndicatorNeuralNetwork},
mesh, equations::AbstractEquations{1}, dg::DGSEM, cache)
create_cache(typ, equations, dg.basis)
end

Expand Down
20 changes: 10 additions & 10 deletions src/indicators_2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

# this method is used when the indicator is constructed as for shock-capturing volume integrals
# empty cache is default
function create_cache(::Type{IndicatorNeuralNetwork},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
return NamedTuple()
end

# cache for NeuralNetworkPerssonPeraire-type indicator
function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
alpha = Vector{real(basis)}()
alpha_tmp = similar(alpha)
A = Array{real(basis), ndims(equations)}
Expand All @@ -30,8 +30,8 @@ function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkPerssonPeraire}
end

# cache for NeuralNetworkRayHesthaven-type indicator
function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
alpha = Vector{real(basis)}()
alpha_tmp = similar(alpha)
A = Array{real(basis), ndims(equations)}
Expand All @@ -50,8 +50,8 @@ function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkRayHesthaven}},
end

# cache for NeuralNetworkCNN-type indicator
function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkCNN}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
function Trixi.create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkCNN}},
equations::AbstractEquations{2}, basis::LobattoLegendreBasis)
alpha = Vector{real(basis)}()
alpha_tmp = similar(alpha)
A = Array{real(basis), ndims(equations)}
Expand All @@ -69,8 +69,8 @@ function create_cache(::Type{IndicatorNeuralNetwork{NeuralNetworkCNN}},
end

# this method is used when the indicator is constructed as for AMR
function create_cache(typ::Type{<:IndicatorNeuralNetwork},
mesh, equations::AbstractEquations{2}, dg::DGSEM, cache)
function Trixi.create_cache(typ::Type{<:IndicatorNeuralNetwork},
mesh, equations::AbstractEquations{2}, dg::DGSEM, cache)
create_cache(typ, equations, dg.basis)
end

Expand Down

0 comments on commit fb91bd9

Please sign in to comment.