From 2b04d559bafd1abf28014553eaa825cd32a33998 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 4 Oct 2024 14:25:39 -0400 Subject: [PATCH] Try out TypeGroupedNamedTuple --- src/cache/precomputed_quantities.jl | 73 ++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/src/cache/precomputed_quantities.jl b/src/cache/precomputed_quantities.jl index c109e80df4..1f42cfd9b1 100644 --- a/src/cache/precomputed_quantities.jl +++ b/src/cache/precomputed_quantities.jl @@ -4,6 +4,76 @@ import Thermodynamics as TD import ClimaCore: Spaces, Fields +struct UniformNamedTuple{K, T, N} + k::NTuple{N, K} + v::NTuple{N, T} +end +@inline Base.getproperty(nt::UniformNamedTuple, sym::Symbol) = + getproperty(nt, Val(sym)) +to_named_tuple(nt::UniformNamedTuple) = (; zip(nt.k, nt.v)...) +# Does this need to be compile-time known? +@inline function Base.getproperty(nt::UniformNamedTuple, sym::Symbol) + i = findfirst(s -> s == sym, getfield(nt, :k)) + if isnothing(i) + error("No property $sym found in $(to_named_tuple(nt))") + else + @inbounds getproperty(getfield(nt, :v), i) + end +end +Base.eltype(::UniformNamedTuple{K, T, N}) where {K, T, N} = K +Base.length(::UniformNamedTuple{K, T, N}) where {K, T, N} = N +function to_uniform_named_tuple(nt) + k = Tuple(keys(nt)) + v = Tuple(values(nt)) + K = typeof(first(k)) + T = typeof(first(v)) + N = length(v) + return UniformNamedTuple{K, T, N}(k, v) +end + +struct TypeGroupedNamedTuple{M, C} + tnmap::M + cache::C +end +TypeGroupedNamedTuple() = + TypeGroupedNamedTuple{Nothing, Nothing}(nothing, nothing) + +function Base.getproperty(nt::TypeGroupedNamedTuple, sym::Symbol) + cache = getfield(nt, :cache) + tnmap = getfield(nt, :tnmap) + type_grouped_cache = getfield(cache, tnmap[sym]) + getproperty(type_grouped_cache, sym) +end +process_key(k) = Symbol(k) + +to_named_tuple(dict::Dict) = (; (process_key(k) => v for (k, v) in dict)...) + +function type_grouped_named_tuple(flat_cache::NamedTuple) + type_grouped_named_tuple(collect(zip(keys(flat_cache), values(flat_cache)))) +end + +function type_grouped_named_tuple(flat_cache) + type_grouped_cache = Dict{Symbol, Any}() + for (i, (sym, c)) in enumerate(flat_cache) + H = process_key(hash(typeof(c))) + entry = Pair(sym, c) + if haskey(type_grouped_cache, H) + type_grouped_cache[H] = (; type_grouped_cache[H]..., sym => c) + else + type_grouped_cache[H] = (; sym => c) + end + end + type_grouped_cache = + map(x -> to_uniform_named_tuple(x), to_named_tuple(type_grouped_cache)) + K = map(symc -> symc[1], flat_cache) + V = map(symc -> process_key(hash(typeof(symc[2]))), flat_cache) + tnmap = Dict(pairs((; zip(K, V)...))) + return TypeGroupedNamedTuple{typeof(tnmap), typeof(type_grouped_cache)}( + tnmap, + type_grouped_cache, + ) +end + """ precomputed_quantities(Y, atmos) @@ -156,7 +226,7 @@ function precomputed_quantities(Y, atmos) ᶜqᵣ = similar(Y.c, FT), ᶜqₛ = similar(Y.c, FT), ) : (;) - return (; + nt = (; gs_quantities..., sgs_quantities..., advective_sgs_quantities..., @@ -165,6 +235,7 @@ function precomputed_quantities(Y, atmos) precipitation_quantities..., cloud_diagnostics_tuple, ) + return type_grouped_named_tuple(nt) end # Interpolates the third contravariant component of Y.c.uₕ to cell faces.