Skip to content

Commit

Permalink
Merge pull request #92 from bcube-project/dev_dual_support
Browse files Browse the repository at this point in the history
Better support of different kind of DoF types with `FEFunction` and `assemble_linear`
  • Loading branch information
ghislainb authored May 14, 2024
2 parents da8542d + 845003b commit 40b51de
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 40 deletions.
8 changes: 6 additions & 2 deletions src/assembler/assembler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ julia> assemble_linear(l, V)
0.25
```
"""
function assemble_linear(l::Function, V::Union{TestFESpace, AbstractMultiTestFESpace})
b = zeros(get_ndofs(V)) # TODO : specify the eltype (Float64, Dual,...)
function assemble_linear(
l::Function,
V::Union{TestFESpace, AbstractMultiTestFESpace};
T = Float64,
)
b = zeros(T, get_ndofs(V))
assemble_linear!(b, l, V)
return b
end
Expand Down
24 changes: 24 additions & 0 deletions src/cellfunction/cellfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,27 @@ LazyOperators.materialize(f::Function, ::AbstractLazy) = f
LazyOperators.materialize(a::AbstractArray{<:Number}, ::AbstractLazy) = a
LazyOperators.materialize(a::Number, ::AbstractLazy) = a
LazyOperators.materialize(a::LinearAlgebra.UniformScaling, ::AbstractLazy) = a

"""
get_return_type_and_codim(f::AbstractLazy, elementInfo::AbstractDomainIndex)
get_return_type_and_codim(f::AbstractLazy, domain::AbstractDomain)
get_return_type_and_codim(f::AbstractLazy, mesh::AbstractMesh)
Evaluate the returned type and the codimension of `materialize(f,x)` where
`x` is a `ElementPoint` from `elementInfo` (or `domain`/`mesh`).
The returned codimension is always a Tuple, even for a scalar.
"""
function get_return_type_and_codim(f::AbstractLazy, elementInfo::AbstractDomainIndex)
fₑ = materialize(f, elementInfo)
elementPoint = get_dummy_element_point(elementInfo)
value = materialize(fₑ, elementPoint)
N = value isa Number ? (1,) : size(value)
T = eltype(value)
return T, N
end
function get_return_type_and_codim(f::AbstractLazy, domain::AbstractDomain)
get_return_type_and_codim(f, first(DomainIterator(domain)))
end
function get_return_type_and_codim(f::AbstractLazy, mesh::AbstractMesh)
get_return_type_and_codim(f, CellInfo(mesh, 1))
end
27 changes: 27 additions & 0 deletions src/cellfunction/eval_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,30 @@ function change_domain(p::FacePoint{ReferenceDomain}, ::PhysicalDomain)
x_phy = _apply_mapping(m, get_coords(p))
FacePoint(x_phy, faceInfo, PhysicalDomain())
end

ElementPoint(x, elementInfo::CellInfo, ds) = CellPoint(x, elementInfo, ds)
ElementPoint(x, elementInfo::FaceInfo, ds) = FacePoint(x, elementInfo, ds)

"""
get_dummy_element_point(elementInfo::AbstractDomainIndex)
get_dummy_element_point(domain::AbstractDomain)
Return a `CellPoint` (or a `FacePoint` depending on the
type of `elementInfo`) whose coordinates are equals
to the center of the reference shape of the element.
For a `domain` argument, the point is built from its firt element.
# Devs notes:
These utility functions can be used to easily materialize
a `CellFunction` and get the type of the result for example.
"""
function get_dummy_element_point(elementInfo::AbstractDomainIndex)
x = center(shape(get_element_type(elementInfo)))
ElementPoint(x, elementInfo, ReferenceDomain())
end

function get_dummy_element_point(domain::AbstractDomain)
elementInfo = first(DomainIterator(domain))
get_dummy_element_point(elementInfo)
end
2 changes: 1 addition & 1 deletion src/feoperator/limiter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function linear_scaling_limiter(
mass = nothing,
) where {N, Me, BC <: PeriodicBCType}
lim_u, u̅ = linear_scaling_limiter_coef(u, dω, bounds, DMPrelax, periodicBCs)
u_lim = FEFunction(get_fespace(u))
u_lim = FEFunction(get_fespace(u), get_dof_type(u))
projection_l2!(u_lim, u̅ + lim_u * (u - u̅), dω; mass = mass)
lim_u, u_lim
end
27 changes: 2 additions & 25 deletions src/feoperator/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ all the data obtained from the neighbor cells of this node. We could use
the surface area (among other possible choices).
"""
function var_on_vertices(f::AbstractLazy, mesh::Mesh)
N, T = _codim_and_type(f, mesh)
T, N = get_return_type_and_codim(f, mesh)
@assert length(N) <= 1 "N = $(length(N)) > 1 not supported yet"
values = zeros(T, nnodes(mesh), N[1])
_var_on_vertices!(values, f, mesh)
Expand Down Expand Up @@ -51,29 +51,6 @@ function _var_on_vertices!(values, f::AbstractLazy, mesh::Mesh)
end
end

"""
_codim_and_type(f::AbstractLazy, mesh::Mesh)
Evaluate the codimension of `f` and returned type. The returned codimension
is always a Tuple of codimension(s), even for a scalar.
"""
function _codim_and_type(f::AbstractLazy, mesh::Mesh)
# Get info about first cell of the mesh
cInfo = CellInfo(mesh, 1)

# Materialize FE function on CellInfo
_f = materialize(f, cInfo)

# Evaluate the function on the center of the cell
cPoint = CellPoint(center(shape(celltype(cInfo))), cInfo, ReferenceDomain())
value = materialize(_f, cPoint)

# Codim and type
N = value isa Number ? (1,) : size(value)
T = eltype(value)
return N, T
end

"""
var_on_centers(f::AbstractLazy, mesh::AbstractMesh)
Expand All @@ -82,7 +59,7 @@ Interpolate solution on mesh centers.
The result is a (ncells, ncomps) matrix if ncomps > 1, or a (ncells) vector otherwise.
"""
function var_on_centers(f::AbstractLazy, mesh::AbstractMesh)
N, T = _codim_and_type(f, mesh)
T, N = get_return_type_and_codim(f, mesh)
@assert length(N) <= 1 "N = $(length(N)) > 1 not supported yet"
values = zeros(T, ncells(mesh), N[1])
_var_on_centers!(values, f, mesh)
Expand Down
5 changes: 3 additions & 2 deletions src/feoperator/projection_newapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ function projection_l2!(u::AbstractSingleFieldFEFunction, f, dΩ::Measure; mass
A = mass
end
l(v) = (f v)dΩ
b = assemble_linear(l, V)
T, = get_return_type_and_codim(f, get_domain(dΩ))
b = assemble_linear(l, V; T = T)
x = A \ b
set_dof_values!(u, x)
return nothing
Expand Down Expand Up @@ -105,7 +106,7 @@ function cell_mean(u::MultiFieldFEFunction, cache::Tuple{Vararg{CellMeanCache}})
end
function cell_mean(u::AbstractFEFunction, cache::CellMeanCache)
Umean = get_fespace(cache)
u_mean = FEFunction(Umean)
u_mean = FEFunction(Umean, get_dof_type(u))
projection_l2!(u_mean, u, get_measure(cache); mass = get_mass_matrix(cache))
values = _reshape_cell_mean(u_mean, Val(get_size(Umean)))
return MeshCellData(values)
Expand Down
20 changes: 16 additions & 4 deletions src/fespace/fefunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ end
function get_fespace(f::AbstractFEFunction)
error("`get_fespace` is not defined for $(typeof(f))")
end

function get_dof_type(f::AbstractFEFunction)
error("`get_dof_type` is not defined for type $(typeof(f))")
end
function get_dof_values(f::AbstractFEFunction)
error("`get_dof_values` is not defined for type $(typeof(f))")
end
Expand Down Expand Up @@ -83,18 +85,24 @@ function SingleFieldFEFunction(feSpace::AbstractFESpace, dofValues)
return SingleFieldFEFunction{size, FE, V}(feSpace, dofValues)
end

function FEFunction(feSpace::AbstractFESpace, dofValues = allocate_dofs(feSpace))
function FEFunction(feSpace::AbstractFESpace, dofValues)
SingleFieldFEFunction(feSpace, dofValues)
end

function FEFunction(feSpace::AbstractFESpace, T::Type{<:Number} = Float64)
dofValues = allocate_dofs(feSpace, T)
FEFunction(feSpace, dofValues)
end

function FEFunction(feSpace::AbstractFESpace, constant::Number)
feFunction = FEFunction(feSpace)
feFunction = FEFunction(feSpace, typeof(constant))
feFunction.dofValues .= constant
return feFunction
end

get_fespace(f::SingleFieldFEFunction) = f.feSpace
get_ncomponents(f::SingleFieldFEFunction) = get_ncomponents(get_fespace(f))
get_dof_type(f::SingleFieldFEFunction) = eltype(get_dof_values(f))
get_dof_values(f::SingleFieldFEFunction) = f.dofValues
function get_dof_values(f::SingleFieldFEFunction, icell)
feSpace = get_fespace(f)
Expand Down Expand Up @@ -140,6 +148,10 @@ end
end
@inline _get_mfe_space(mfeFunc::MultiFieldFEFunction) = mfeFunc.mfeSpace

function get_dof_type(mfeFunc::MultiFieldFEFunction)
mapreduce(get_dof_type, promote, (mfeFunc...,))
end

"""
Update the vector `u` with the values of each `FEFunction` composing this MultiFieldFEFunction.
The mapping of the associated MultiFESpace is respected.
Expand All @@ -151,7 +163,7 @@ function get_dof_values!(u::AbstractVector{<:Number}, mfeFunc::MultiFieldFEFunct
end

function get_dof_values(mfeFunc::MultiFieldFEFunction)
u = allocate_dofs(_get_mfe_space(mfeFunc))
u = mapreduce(get_dof_values, vcat, (mfeFunc...,))
get_dof_values!(u, mfeFunc)
return u
end
Expand Down
14 changes: 13 additions & 1 deletion src/mesh/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ function BoundaryFaceDomain(mesh::AbstractMesh, args...; kwargs...)
BoundaryFaceDomain(parent(mesh), args...; kwargs...)
end

"""
abstract type AbstractDomainIndex
# Devs notes
All subtypes should implement the following functions:
- get_element_type(::AbstractDomainIndex)
"""
abstract type AbstractDomainIndex end

abstract type AbstractCellInfo <: AbstractDomainIndex end
Expand All @@ -250,6 +257,7 @@ end
@inline celltype(c::CellInfo) = c.ctype
@inline nodes(c::CellInfo) = c.nodes
@inline get_nodes_index(c::CellInfo) = c.c2n
get_element_type(c::CellInfo) = celltype(c)

""" Legacy constructor for CellInfo : no information about node indices """
CellInfo(icell, ctype, nodes) = CellInfo(icell, ctype, nodes, nothing)
Expand Down Expand Up @@ -285,6 +293,9 @@ end
@inline celltype(c::CellSide) = c.ctype
@inline nodes(c::CellSide) = c.nodes
@inline cell2nodes(c::CellSide) = c.c2n
get_element_type(c::CellSide) = celltype(c)

abstract type AbstractFaceInfo <: AbstractDomainIndex end

"""
FaceInfo{CN<:CellInfo,CP<:CellInfo,FT,FN,F2N,I}
Expand All @@ -302,7 +313,7 @@ are duplicate from the negative ones.
is stored explicitely in `FaceInfo` even if it could have been
computed by collecting info from the side of the negative or positive cells.
"""
struct FaceInfo{CN <: CellInfo, CP <: CellInfo, FT, FN, F2N, I}
struct FaceInfo{CN <: CellInfo, CP <: CellInfo, FT, FN, F2N, I} <: AbstractFaceInfo
cellinfo_n::CN
cellinfo_p::CP
cellside_n::Int
Expand Down Expand Up @@ -379,6 +390,7 @@ get_cellinfo_p(faceInfo::FaceInfo) = faceInfo.cellinfo_p
@inline get_nodes_index(faceInfo::FaceInfo) = faceInfo.f2n
get_cell_side_n(faceInfo::FaceInfo) = faceInfo.cellside_n
get_cell_side_p(faceInfo::FaceInfo) = faceInfo.cellside_p
get_element_type(c::FaceInfo) = facetype(c)

"""
Return the opposite side of the `FaceInfo` : cellside "n" because cellside "p"
Expand Down
4 changes: 2 additions & 2 deletions src/writers/vtk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,11 @@ function write_vtk_lagrange(
nd = ndofs(dhl_export)

# Get ncomps and type of each `point` variable
dimtype = map(var -> _codim_and_type(var, mesh), values(vars_point))
type_dim = map(var -> get_return_type_and_codim(var, mesh), values(vars_point))

# VTK stuff
coords_vtk = zeros(spacedim(mesh), nd)
values_vtk = map(_dimtype -> zeros(last(_dimtype), first(_dimtype)..., nd), dimtype)
values_vtk = map(((_t, _d),) -> zeros(_t, _d..., nd), type_dim)
cells_vtk = MeshCell[]
sizehint!(cells_vtk, ncells(mesh))
nodeweigth_vtk = zeros(nd)
Expand Down
6 changes: 3 additions & 3 deletions test/interpolation/test_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@

@testset "misc" begin
mesh = one_cell_mesh(:quad)
N, T = Bcube._codim_and_type(PhysicalFunction(x -> x[1]), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> x[1]), mesh)
@test N == (1,)
@test T == Float64
N, T = Bcube._codim_and_type(PhysicalFunction(x -> 1), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> 1), mesh)
@test N == (1,)
@test T == Int
N, T = Bcube._codim_and_type(PhysicalFunction(x -> x), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> x), mesh)
@test N == (2,)
@test T == Float64
end
Expand Down

0 comments on commit 40b51de

Please sign in to comment.