Skip to content

Commit

Permalink
add function get_return_type_and_codim
Browse files Browse the repository at this point in the history
  • Loading branch information
ghislainb committed May 13, 2024
1 parent 1fb75e1 commit 845003b
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 46 deletions.
28 changes: 22 additions & 6 deletions src/cellfunction/cellfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,26 @@ LazyOperators.materialize(a::AbstractArray{<:Number}, ::AbstractLazy) = a
LazyOperators.materialize(a::Number, ::AbstractLazy) = a
LazyOperators.materialize(a::LinearAlgebra.UniformScaling, ::AbstractLazy) = a

# TODO : find a better place and a better granularity
function get_return_type(f, domain::AbstractDomain)
f1 = materialize(f, first(DomainIterator(domain)))
elementPoint = get_dummy_element_point(domain)
value = materialize(f1, elementPoint)
return eltype(value)
"""
get_return_type_and_codim(f::AbstractLazy, elementInfo::AbstractDomainIndex)
get_return_type_and_codim(f::AbstractLazy, domain::AbstractDomain)
get_return_type_and_codim(f::AbstractLazy, mesh::AbstractMesh)
Evaluate the returned type and the codimension of `materialize(f,x)` where
`x` is a `ElementPoint` from `elementInfo` (or `domain`/`mesh`).
The returned codimension is always a Tuple, even for a scalar.
"""
function get_return_type_and_codim(f::AbstractLazy, elementInfo::AbstractDomainIndex)
fₑ = materialize(f, elementInfo)
elementPoint = get_dummy_element_point(elementInfo)
value = materialize(fₑ, elementPoint)
N = value isa Number ? (1,) : size(value)
T = eltype(value)
return T, N
end
function get_return_type_and_codim(f::AbstractLazy, domain::AbstractDomain)
get_return_type_and_codim(f, first(DomainIterator(domain)))
end
function get_return_type_and_codim(f::AbstractLazy, mesh::AbstractMesh)
get_return_type_and_codim(f, CellInfo(mesh, 1))
end
23 changes: 15 additions & 8 deletions src/cellfunction/eval_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,18 +221,25 @@ ElementPoint(x, elementInfo::CellInfo, ds) = CellPoint(x, elementInfo, ds)
ElementPoint(x, elementInfo::FaceInfo, ds) = FacePoint(x, elementInfo, ds)

Check warning on line 221 in src/cellfunction/eval_point.jl

View check run for this annotation

Codecov / codecov/patch

src/cellfunction/eval_point.jl#L221

Added line #L221 was not covered by tests

"""
get_dummy_element_point(elementInfo::AbstractDomainIndex)
get_dummy_element_point(domain::AbstractDomain)
Return a `CellPoint` (or a `FacePoint` depending on the
type of `domain`) associated with the first element on the
domain and whose coordinates correspond to the origin of
its `ReferenceDomain`.
This utility function can be used to easily materialize
a `CellFunction` and know the type of the result for example.
type of `elementInfo`) whose coordinates are equals
to the center of the reference shape of the element.
For a `domain` argument, the point is built from its firt element.
# Devs notes:
These utility functions can be used to easily materialize
a `CellFunction` and get the type of the result for example.
"""
function get_dummy_element_point(elementInfo::AbstractDomainIndex)
x = center(shape(get_element_type(elementInfo)))
ElementPoint(x, elementInfo, ReferenceDomain())
end

function get_dummy_element_point(domain::AbstractDomain)
mesh = get_mesh(domain)
elementInfo = first(DomainIterator(domain))
xnode1 = get_coords(get_nodes(mesh, 1))
ElementPoint(zero(xnode1), elementInfo, ReferenceDomain())
get_dummy_element_point(elementInfo)

Check warning on line 244 in src/cellfunction/eval_point.jl

View check run for this annotation

Codecov / codecov/patch

src/cellfunction/eval_point.jl#L242-L244

Added lines #L242 - L244 were not covered by tests
end
27 changes: 2 additions & 25 deletions src/feoperator/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ all the data obtained from the neighbor cells of this node. We could use
the surface area (among other possible choices).
"""
function var_on_vertices(f::AbstractLazy, mesh::Mesh)
N, T = _codim_and_type(f, mesh)
T, N = get_return_type_and_codim(f, mesh)
@assert length(N) <= 1 "N = $(length(N)) > 1 not supported yet"
values = zeros(T, nnodes(mesh), N[1])
_var_on_vertices!(values, f, mesh)
Expand Down Expand Up @@ -51,29 +51,6 @@ function _var_on_vertices!(values, f::AbstractLazy, mesh::Mesh)
end
end

"""
_codim_and_type(f::AbstractLazy, mesh::Mesh)
Evaluate the codimension of `f` and returned type. The returned codimension
is always a Tuple of codimension(s), even for a scalar.
"""
function _codim_and_type(f::AbstractLazy, mesh::Mesh)
# Get info about first cell of the mesh
cInfo = CellInfo(mesh, 1)

# Materialize FE function on CellInfo
_f = materialize(f, cInfo)

# Evaluate the function on the center of the cell
cPoint = CellPoint(center(shape(celltype(cInfo))), cInfo, ReferenceDomain())
value = materialize(_f, cPoint)

# Codim and type
N = value isa Number ? (1,) : size(value)
T = eltype(value)
return N, T
end

"""
var_on_centers(f::AbstractLazy, mesh::AbstractMesh)
Expand All @@ -82,7 +59,7 @@ Interpolate solution on mesh centers.
The result is a (ncells, ncomps) matrix if ncomps > 1, or a (ncells) vector otherwise.
"""
function var_on_centers(f::AbstractLazy, mesh::AbstractMesh)
N, T = _codim_and_type(f, mesh)
T, N = get_return_type_and_codim(f, mesh)
@assert length(N) <= 1 "N = $(length(N)) > 1 not supported yet"
values = zeros(T, ncells(mesh), N[1])
_var_on_centers!(values, f, mesh)
Expand Down
3 changes: 2 additions & 1 deletion src/feoperator/projection_newapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ function projection_l2!(u::AbstractSingleFieldFEFunction, f, dΩ::Measure; mass
A = mass
end
l(v) = (f v)dΩ
b = assemble_linear(l, V; T = get_return_type(f, get_domain(dΩ)))
T, = get_return_type_and_codim(f, get_domain(dΩ))
b = assemble_linear(l, V; T = T)
x = A \ b
set_dof_values!(u, x)
return nothing
Expand Down
14 changes: 13 additions & 1 deletion src/mesh/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ function BoundaryFaceDomain(mesh::AbstractMesh, args...; kwargs...)
BoundaryFaceDomain(parent(mesh), args...; kwargs...)
end

"""
abstract type AbstractDomainIndex
# Devs notes
All subtypes should implement the following functions:
- get_element_type(::AbstractDomainIndex)
"""
abstract type AbstractDomainIndex end

abstract type AbstractCellInfo <: AbstractDomainIndex end
Expand All @@ -250,6 +257,7 @@ end
@inline celltype(c::CellInfo) = c.ctype
@inline nodes(c::CellInfo) = c.nodes
@inline get_nodes_index(c::CellInfo) = c.c2n
get_element_type(c::CellInfo) = celltype(c)

""" Legacy constructor for CellInfo : no information about node indices """
CellInfo(icell, ctype, nodes) = CellInfo(icell, ctype, nodes, nothing)
Expand Down Expand Up @@ -285,6 +293,9 @@ end
@inline celltype(c::CellSide) = c.ctype
@inline nodes(c::CellSide) = c.nodes
@inline cell2nodes(c::CellSide) = c.c2n
get_element_type(c::CellSide) = celltype(c)

Check warning on line 296 in src/mesh/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/mesh/domain.jl#L296

Added line #L296 was not covered by tests

abstract type AbstractFaceInfo <: AbstractDomainIndex end

"""
FaceInfo{CN<:CellInfo,CP<:CellInfo,FT,FN,F2N,I}
Expand All @@ -302,7 +313,7 @@ are duplicate from the negative ones.
is stored explicitely in `FaceInfo` even if it could have been
computed by collecting info from the side of the negative or positive cells.
"""
struct FaceInfo{CN <: CellInfo, CP <: CellInfo, FT, FN, F2N, I}
struct FaceInfo{CN <: CellInfo, CP <: CellInfo, FT, FN, F2N, I} <: AbstractFaceInfo
cellinfo_n::CN
cellinfo_p::CP
cellside_n::Int
Expand Down Expand Up @@ -379,6 +390,7 @@ get_cellinfo_p(faceInfo::FaceInfo) = faceInfo.cellinfo_p
@inline get_nodes_index(faceInfo::FaceInfo) = faceInfo.f2n
get_cell_side_n(faceInfo::FaceInfo) = faceInfo.cellside_n
get_cell_side_p(faceInfo::FaceInfo) = faceInfo.cellside_p
get_element_type(c::FaceInfo) = facetype(c)

Check warning on line 393 in src/mesh/domain.jl

View check run for this annotation

Codecov / codecov/patch

src/mesh/domain.jl#L393

Added line #L393 was not covered by tests

"""
Return the opposite side of the `FaceInfo` : cellside "n" because cellside "p"
Expand Down
4 changes: 2 additions & 2 deletions src/writers/vtk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,11 @@ function write_vtk_lagrange(
nd = ndofs(dhl_export)

# Get ncomps and type of each `point` variable
dimtype = map(var -> _codim_and_type(var, mesh), values(vars_point))
type_dim = map(var -> get_return_type_and_codim(var, mesh), values(vars_point))

# VTK stuff
coords_vtk = zeros(spacedim(mesh), nd)
values_vtk = map(_dimtype -> zeros(last(_dimtype), first(_dimtype)..., nd), dimtype)
values_vtk = map(((_t, _d),) -> zeros(_t, _d..., nd), type_dim)
cells_vtk = MeshCell[]
sizehint!(cells_vtk, ncells(mesh))
nodeweigth_vtk = zeros(nd)
Expand Down
6 changes: 3 additions & 3 deletions test/interpolation/test_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@

@testset "misc" begin
mesh = one_cell_mesh(:quad)
N, T = Bcube._codim_and_type(PhysicalFunction(x -> x[1]), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> x[1]), mesh)
@test N == (1,)
@test T == Float64
N, T = Bcube._codim_and_type(PhysicalFunction(x -> 1), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> 1), mesh)
@test N == (1,)
@test T == Int
N, T = Bcube._codim_and_type(PhysicalFunction(x -> x), mesh)
T, N = Bcube.get_return_type_and_codim(PhysicalFunction(x -> x), mesh)
@test N == (2,)
@test T == Float64
end
Expand Down

0 comments on commit 845003b

Please sign in to comment.