diff --git a/src/cache/precomputed_quantities.jl b/src/cache/precomputed_quantities.jl index c109e80df4..a761f60d1a 100644 --- a/src/cache/precomputed_quantities.jl +++ b/src/cache/precomputed_quantities.jl @@ -4,6 +4,77 @@ import Thermodynamics as TD import ClimaCore: Spaces, Fields +struct UniformNamedTuple{K, T} + k::Vector{K} + v::Vector{T} +end +# @inline Base.getproperty(nt::UniformNamedTuple, sym::Symbol) = +# getproperty(nt, Val(sym)) +to_named_tuple(nt::UniformNamedTuple) = (; zip(nt.k, nt.v)...) +# Does this need to be compile-time known? +@inline function Base.getproperty(nt::UniformNamedTuple, sym::Symbol) + i = findfirst(s -> s == sym, getfield(nt, :k)) + if isnothing(i) + error("No property $sym found in $(to_named_tuple(nt))") + else + @inbounds getindex(getfield(nt, :v), i) + end +end +Base.eltype(::UniformNamedTuple{K, T}) where {K, T} = K +Base.length(nt::UniformNamedTuple{K, T}) where {K, T} = length(getfield(nt, :v)) +function to_uniform_named_tuple(nt) + k = collect(Tuple(keys(nt))) + v = collect(Tuple(values(nt))) + K = typeof(first(k)) + T = typeof(first(v)) + return UniformNamedTuple{K, T}(k, v) +end + +struct TypeGroupedNamedTuple{M, C} + tnmap::M + cache::C +end +TypeGroupedNamedTuple() = + TypeGroupedNamedTuple{Nothing, Nothing}(nothing, nothing) + +function Base.getproperty(nt::TypeGroupedNamedTuple, sym::Symbol) + cache = getfield(nt, :cache) + tnmap = getfield(nt, :tnmap) + type_grouped_cache = getfield(cache, tnmap[sym]) + getproperty(type_grouped_cache, sym) +end +process_key(k) = Symbol(k) + +to_named_tuple(dict::Dict) = (; (process_key(k) => v for (k, v) in dict)...) + +function type_grouped_named_tuple(flat_cache::NamedTuple) + type_grouped_named_tuple(collect(zip(keys(flat_cache), values(flat_cache)))) +end + +Base.propertynames(nt::TypeGroupedNamedTuple) = + Tuple(keys(getfield(nt, :tnmap))) +function type_grouped_named_tuple(flat_cache) + type_grouped_cache = Dict{Symbol, Any}() + for (i, (sym, c)) in enumerate(flat_cache) + H = process_key(hash(typeof(c))) + entry = Pair(sym, c) + if haskey(type_grouped_cache, H) + type_grouped_cache[H] = (; type_grouped_cache[H]..., sym => c) + else + type_grouped_cache[H] = (; sym => c) + end + end + type_grouped_cache = + map(x -> to_uniform_named_tuple(x), to_named_tuple(type_grouped_cache)) + K = map(symc -> symc[1], flat_cache) + V = map(symc -> process_key(hash(typeof(symc[2]))), flat_cache) + tnmap = Dict(pairs((; zip(K, V)...))) + return TypeGroupedNamedTuple{typeof(tnmap), typeof(type_grouped_cache)}( + tnmap, + type_grouped_cache, + ) +end + """ precomputed_quantities(Y, atmos) @@ -156,7 +227,7 @@ function precomputed_quantities(Y, atmos) ᶜqᵣ = similar(Y.c, FT), ᶜqₛ = similar(Y.c, FT), ) : (;) - return (; + nt = (; gs_quantities..., sgs_quantities..., advective_sgs_quantities..., @@ -165,6 +236,9 @@ function precomputed_quantities(Y, atmos) precipitation_quantities..., cloud_diagnostics_tuple, ) + tgnt = type_grouped_named_tuple(nt) + @show typeof(tgnt) + return tgnt end # Interpolates the third contravariant component of Y.c.uₕ to cell faces.