From f71b6d83fb08f8b57165aedb795f19ec2c9b8df0 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Sun, 26 Jan 2025 23:23:32 -0500 Subject: [PATCH] Fix reinstantiation of spectral broadcasted --- src/Operators/spectralelement.jl | 6 ++- test/Operators/unit_reinstantiate_bc.jl | 53 +++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 test/Operators/unit_reinstantiate_bc.jl diff --git a/src/Operators/spectralelement.jl b/src/Operators/spectralelement.jl index aa93d48c98..284b9c0eec 100644 --- a/src/Operators/spectralelement.jl +++ b/src/Operators/spectralelement.jl @@ -127,7 +127,9 @@ function Base.Broadcast.instantiate(sbc::SpectralBroadcasted) Base.Broadcast.check_broadcast_axes(axes, args...) end end - op = typeof(op)(axes) + # If we've already instantiated, then we need to strip the type parameters, + # for example, `Divergence{()}(axes)`. + op = unionall_type(typeof(op)){()}(axes) Style = AbstractSpectralStyle(ClimaComms.device(axes)) return SpectralBroadcasted{Style}(op, args, axes) end @@ -1323,6 +1325,7 @@ struct Interpolate{I, S} <: TensorOperator space::S end Interpolate(space) = Interpolate{operator_axes(space), typeof(space)}(space) +Interpolate{()}(space) = Interpolate{operator_axes(space), typeof(space)}(space) function apply_operator(op::Interpolate{(1,)}, space_out, slabidx, arg) FT = Spaces.undertype(space_out) @@ -1412,6 +1415,7 @@ struct Restrict{I, S} <: TensorOperator space::S end Restrict(space) = Restrict{operator_axes(space), typeof(space)}(space) +Restrict{()}(space) = Restrict{operator_axes(space), typeof(space)}(space) function apply_operator(op::Restrict{(1,)}, space_out, slabidx, arg) FT = Spaces.undertype(space_out) diff --git a/test/Operators/unit_reinstantiate_bc.jl b/test/Operators/unit_reinstantiate_bc.jl new file mode 100644 index 0000000000..7637b34f97 --- /dev/null +++ b/test/Operators/unit_reinstantiate_bc.jl @@ -0,0 +1,53 @@ +#= +julia --project=.buildkite +using Revise; include("test/Operators/unit_reinstantiate_bc.jl") +=# + +# TODO: make this unit test more low-level +using ClimaComms +ClimaComms.@import_required_backends +using ClimaCore.CommonSpaces +using ClimaCore: Spaces, Fields, Geometry, ClimaCore, Operators +using LazyBroadcast: lazy +using Test +using Base.Broadcast: materialize + +const divₕ = Operators.Divergence() +const wgradₕ = Operators.WeakGradient() +const curlₕ = Operators.Curl() +const wcurlₕ = Operators.WeakCurl() + +using ClimaCore.CommonSpaces + +function foo_tendency_uₕ(ᶜuₕ, zmax) + return lazy.( + @. ( + wgradₕ(divₕ(ᶜuₕ)) - Geometry.project( + Geometry.Covariant12Axis(), + wcurlₕ(Geometry.project(Geometry.Covariant3Axis(), curlₕ(ᶜuₕ))), + ) + ) + ) +end + +@testset "Reinstantiation of SpectralBroadcasted" begin + FT = Float64 + ᶜspace = ExtrudedCubedSphereSpace( + FT; + z_elem = 10, + z_min = 0, + z_max = 1, + radius = 10, + h_elem = 10, + n_quad_points = 4, + staggering = CellCenter(), + ) + ᶠspace = Spaces.face_space(ᶜspace) + ᶠz = Fields.coordinate_field(ᶠspace).z + ᶜz = Fields.coordinate_field(ᶜspace).z + ᶜuₕ = map(z -> zero(Geometry.Covariant12Vector{eltype(z)}), ᶜz) + zmax = Spaces.z_max(axes(ᶠz)) + vst_uₕ = foo_tendency_uₕ(ᶜuₕ, zmax) + ᶜuₕₜ = zero(ᶜuₕ) + @. ᶜuₕₜ += vst_uₕ +end diff --git a/test/runtests.jl b/test/runtests.jl index 74e5c79154..a64d8b2843 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,7 @@ UnitTest("Spaces - DSS cubed sphere" ,"Spaces/ddss1_cs.jl"), UnitTest("Sphere spaces" ,"Spaces/sphere.jl"), # UnitTest("Terrain warp" ,"Spaces/terrain_warp.jl"), # appears to hang on GHA UnitTest("Fields" ,"Fields/unit_field.jl"), # has benchmarks +UnitTest("Reinstantiate broadcasted" ,"Operators/unit_reinstantiate_bc.jl"), UnitTest("Spectral elem - rectilinear" ,"Operators/spectralelement/rectilinear.jl"), UnitTest("Spectral elem - opt" ,"Operators/spectralelement/opt.jl"), UnitTest("Spectral elem - gradient tensor" ,"Operators/spectralelement/covar_deriv_ops.jl"),