Skip to content

Commit

Permalink
Merge pull request #2124 from CliMA/ck/save_stretch
Browse files Browse the repository at this point in the history
Save the stretch type in the IntervalMesh
  • Loading branch information
charleskawczynski authored Feb 2, 2025
2 parents 71025be + 70fd6f4 commit 83b6ccc
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 24 deletions.
37 changes: 29 additions & 8 deletions src/InputOutput/readers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ import ..Geometry:
Covariant3Vector
using ..DataLayouts

function get_key(file, key)
if haskey(file, key)
return file[key]
else
msg = "Key `$key` not found in HDF5Reader.\n"
msg *= "Available keys:\n"
msg *= string(keys(file))
error(msg)
end
end

"""
HDF5Reader(filename::AbstractString[, context::ClimaComms.AbstractCommsContext])
HDF5Reader(::Function, filename::AbstractString[, context::ClimaComms.AbstractCommsContext])
Expand Down Expand Up @@ -259,23 +270,32 @@ function read_mesh(reader, name)
end

function read_mesh_new(reader::HDF5Reader, name::AbstractString)
group = reader.file["meshes/$name"]
group = get_key(reader.file, "meshes/$name")
type = attrs(group)["type"]
if type == "IntervalMesh"
domain = read_domain(reader, attrs(group)["domain"])
nelements = attrs(group)["nelements"]
faces_type = attrs(group)["faces_type"]
if faces_type == "Range"
faces_type = get(attrs(group), "faces_type", nothing)
stretch_type = get(attrs(group), "stretch_type", nothing)
if stretch_type == "Uniform" || faces_type == "Range"
return Meshes.IntervalMesh(
domain,
Meshes.Uniform(),
Meshes.Uniform();
nelems = nelements,
)
else
end
if stretch_type "UnknownStretch"
stretch_params = attrs(group)["stretch_params"]
CT = Domains.coordinate_type(domain)
faces = [CT(coords) for coords in attrs(group)["faces"]]
return Meshes.IntervalMesh(domain, faces)
stretch =
getproperty(Meshes, Symbol(stretch_type))(stretch_params...)
return Meshes.IntervalMesh(domain, stretch; nelems = nelements)
end
# Fallback: read from array
@assert faces_type == "Array"
CT = Domains.coordinate_type(domain)
faces = [CT(coords) for coords in attrs(group)["faces"]]
return Meshes.IntervalMesh(domain, faces)
elseif type == "RectilinearMesh"
intervalmesh1 = read_mesh(reader, attrs(group)["intervalmesh1"])
intervalmesh2 = read_mesh(reader, attrs(group)["intervalmesh2"])
Expand Down Expand Up @@ -480,7 +500,8 @@ cached, so that reading the same field multiple times will create multiple
distinct objects.
"""
function read_field(reader::HDF5Reader, name::AbstractString)
obj = reader.file["fields/$name"]
key = "fields/$name"
obj = get_key(reader.file, key)
type = attrs(obj)["type"]
if type == "Field"
if haskey(attrs(obj), "grid")
Expand Down
9 changes: 9 additions & 0 deletions src/InputOutput/writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,15 @@ function write_new!(
[getfield(mesh.faces[i], 1) for i in 1:length(mesh.faces)],
)
end
(; stretch) = mesh
write_attribute(group, "stretch_type", string(nameof(typeof(stretch))))
fns = fieldnames(typeof(stretch))
if !isempty(fns)
vals = map(fns) do fn
getfield(stretch, fn)
end
write_attribute(group, "stretch_params", [vals...])
end
return name
end

Expand Down
30 changes: 24 additions & 6 deletions src/Meshes/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ Constuct a 1D mesh on `domain` with `nelems` elements, using `stretching`. Possi
- [`GeneralizedExponentialStretching(dz_bottom, dz_top)`](@ref)
- [`HyperbolicTangentStretching(dz_bottom)`](@ref)
"""
struct IntervalMesh{I <: IntervalDomain, V <: AbstractVector} <: AbstractMesh1D
struct IntervalMesh{S, I <: IntervalDomain, V <: AbstractVector, M} <:
AbstractMesh1D
stretch::S
domain::I
faces::V
meta::M
end

# implies isequal
Expand Down Expand Up @@ -126,6 +129,21 @@ end

abstract type StretchingRule end

"""
UnknownStretch()
An unknown stretch rule, to be used when constructing an `IntervalMesh` with
given faces.
"""
struct UnknownStretch end

function IntervalMesh(domain::IntervalDomain, faces::AbstractArray)
nelems = length(faces)
nelems < 1 && throw(ArgumentError("`nelems` must be ≥ 1"))
monotonic_check(faces)
IntervalMesh(UnknownStretch(), domain, faces, nothing)
end

"""
Uniform()
Expand All @@ -135,15 +153,15 @@ struct Uniform <: StretchingRule end

function IntervalMesh(
domain::IntervalDomain{CT},
::Uniform = Uniform();
stretch::Uniform = Uniform();
nelems::Int,
) where {CT <: Geometry.Abstract1DPoint{FT}} where {FT}
if nelems < 1
throw(ArgumentError("`nelems` must be ≥ 1"))
end
faces = range(domain.coord_min, domain.coord_max; length = nelems + 1)
monotonic_check(faces)
IntervalMesh(domain, faces)
IntervalMesh(stretch, domain, faces, nothing)
end


Expand Down Expand Up @@ -199,7 +217,7 @@ function IntervalMesh(
reverse!(faces)
end
monotonic_check(faces)
IntervalMesh(domain, faces)
IntervalMesh(stretch, domain, faces, nothing)
end

"""
Expand Down Expand Up @@ -320,7 +338,7 @@ function IntervalMesh(
faces[end] = faces[end] == -z_bottom ? z_bottom : faces[1]
end
monotonic_check(faces)
IntervalMesh(domain, CT.(faces))
IntervalMesh(stretch, domain, CT.(faces), (; h_top))
end

"""
Expand Down Expand Up @@ -409,7 +427,7 @@ function IntervalMesh(
faces[end] = faces[end] == -z_bottom ? z_bottom : faces[1]
end
monotonic_check(faces)
IntervalMesh(domain, CT.(faces))
IntervalMesh(stretch, domain, CT.(faces), (; γ_sol = γ_sol.root))
end

"""
Expand Down
57 changes: 47 additions & 10 deletions test/InputOutput/unit_finitedifference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#=
julia --project=.buildkite
using Revise; include("test/InputOutput/unit_finitedifference.jl")
=#
using Test
import ClimaCore
import ClimaCore.Fields
using ClimaCore: Fields, Meshes, Geometry, Grids, CommonSpaces, InputOutput
using ClimaCore: Domains, Topologies, Spaces

using ClimaComms
const comms_ctx = ClimaComms.context(ClimaComms.CPUSingleThreaded())
Expand All @@ -16,17 +21,17 @@ end
z_min = FT(0)
z_max = FT(30e3)
z_elem = 10
center_staggering = ClimaCore.Grids.CellCenter()
face_staggering = ClimaCore.Grids.CellFace()
center_staggering = Grids.CellCenter()
face_staggering = Grids.CellFace()

center_space = ClimaCore.CommonSpaces.ColumnSpace(;
center_space = CommonSpaces.ColumnSpace(;
z_min,
z_max,
z_elem,
staggering = center_staggering,
)

face_space = ClimaCore.CommonSpaces.ColumnSpace(;
face_space = CommonSpaces.ColumnSpace(;
z_min,
z_max,
z_elem,
Expand All @@ -36,15 +41,47 @@ end
center_field = Fields.local_geometry_field(center_space)
face_field = Fields.local_geometry_field(face_space)

Y = ClimaCore.Fields.FieldVector(; c = center_field, f = face_field)
Y = Fields.FieldVector(; c = center_field, f = face_field)

# write field vector to hdf5 file
ClimaCore.InputOutput.HDF5Writer(filename, comms_ctx) do writer
ClimaCore.InputOutput.write!(writer, Y, "Y")
InputOutput.HDF5Writer(filename, comms_ctx) do writer
InputOutput.write!(writer, Y, "Y")
end

ClimaCore.InputOutput.HDF5Reader(filename, comms_ctx) do reader
restart_Y = ClimaCore.InputOutput.read_field(reader, "Y") # read fieldvector from hdf5 file
InputOutput.HDF5Reader(filename, comms_ctx) do reader
restart_Y = InputOutput.read_field(reader, "Y") # read fieldvector from hdf5 file
@test restart_Y == Y # test if restart is exact
end
end

@testset "HDF5 restart test for 1d finite difference space with unknown mesh" begin
FT = Float32
z_min = FT(0)
z_max = FT(30e3)
z_elem = 10
center_staggering = Grids.CellCenter()
face_staggering = Grids.CellFace()

vdomain = Domains.IntervalDomain(
Geometry.ZPoint{FT}(0.0),
Geometry.ZPoint{FT}(10e3);
boundary_names = (:bottom, :top),
)
vmesh = Meshes.IntervalMesh(vdomain; nelems = 45)
vmesh = Meshes.IntervalMesh(vdomain, vmesh.faces) # pass in faces directly
@test vmesh.stretch isa Meshes.UnknownStretch
context = ClimaComms.context()
vtopology = Topologies.IntervalTopology(context, vmesh)
vspace = Spaces.CenterFiniteDifferenceSpace(vtopology)
f = Fields.Field(FT, vspace)

# write field vector to hdf5 file
InputOutput.HDF5Writer(filename, comms_ctx) do writer
InputOutput.write!(writer, f, "f")
end

InputOutput.HDF5Reader(filename, comms_ctx) do reader
restart_f = InputOutput.read_field(reader, "f") # read field from hdf5 file
@test axes(restart_f).grid.topology.mesh.faces == vmesh.faces
end
end

0 comments on commit 83b6ccc

Please sign in to comment.