From 4e4fd28491b79916b023d35264043ed461439fd7 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 8 Jan 2025 11:12:24 +0000 Subject: [PATCH] Fix Remaining Generated Function Problems (#439) * Fix more generated functions * Excise another generated function * Fix can_produce_... * Fix zero_rdata * Bump patch version --- Project.toml | 2 +- src/fwds_rvs_data.jl | 116 +++++++++++++++++++++++++++++-------------- 2 files changed, 79 insertions(+), 39 deletions(-) diff --git a/Project.toml b/Project.toml index 86c5d1c78..28e472172 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.73" +version = "0.4.74" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index de70b5467..584227fa0 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -259,9 +259,8 @@ function fdata(t::T) where {T<:PossiblyUninitTangent} return is_init(t) ? F(fdata(val(t))) : F() end -@generated function fdata(t::Union{Tuple,NamedTuple}) - fdata_type(t) == NoFData && return NoFData() - return :(tuple_map(fdata, t)) +function fdata(t::T) where {T<:Union{Tuple,NamedTuple}} + return fdata_type(T) == NoFData ? NoFData() : tuple_map(fdata, t) end uninit_fdata(p) = fdata(uninit_tangent(p)) @@ -515,14 +514,37 @@ function rdata(t::T) where {T<:PossiblyUninitTangent} end @generated function rdata(t::Union{Tuple,NamedTuple}) - rdata_type(t) == NoRData && return NoRData() - return :(tuple_map(rdata, t)) + return :(rdata_type($t) == NoRData ? NoRData() : tuple_map(rdata, t)) end -function rdata_backing_type(::Type{P}) where {P} - rdata_field_types = map(n -> rdata_field_type(P, n), 1:fieldcount(P)) - all(==(NoRData), rdata_field_types) && return NoRData - return NamedTuple{fieldnames(P),Tuple{rdata_field_types...}} +""" + rdata_field_types_exprs(::Type{P}) where {P} + +Tuple of expressions. The nth computes the rdata backing type of the nth field of `P`. +""" +function rdata_field_types_exprs(::Type{P}) where {P} + return map(1:fieldcount(P), always_initialised(P)) do n, init + Pf = fieldtype(P, n) + if init + return :(rdata_type(tangent_type($Pf))) + else + return :(PossiblyUninitTangent{rdata_type(tangent_type($Pf))}) + end + end +end + +""" + rdata_backing_type(::Type{P}) where {P} + +The type of the field of `RData` for `P`. +""" +@generated function rdata_backing_type(::Type{P}) where {P} + rdata_field_types_expr = Expr(:call, :tuple, rdata_field_types_exprs(P)...) + return quote + rdata_field_types = $rdata_field_types_expr + stable_all(tuple_map(==(NoRData()), rdata_field_types)) && return NoRData + return NamedTuple{$(fieldnames(P)),Tuple{rdata_field_types...}} + end end """ @@ -535,34 +557,42 @@ zero_rdata(p) zero_rdata(p::IEEEFloat) = zero(p) @generated function zero_rdata(p::P) where {P} - - # Get types associated to primal. - T = tangent_type(P) - R = rdata_type(T) - - # If there's no reverse data, return no reverse data, e.g. for mutable types. - R == NoRData && return :(NoRData()) - - # T ought to be a `Tangent`. If it's not, something has gone wrong. - !(T <: Tangent) && return Expr(:call, error, "Unhandled type $T") - rdata_field_zeros_exprs = ntuple(fieldcount(P)) do n - R_field = rdata_field_type(P, n) - if R_field <: PossiblyUninitTangent - return :(isdefined(p, $n) ? $R_field(zero_rdata(getfield(p, $n))) : $R_field()) - else + Rs = rdata_field_types_exprs(P) + rdata_field_zeros_exprs = map(1:fieldcount(P), always_initialised(P), Rs) do n, init, R + if init return :(zero_rdata(getfield(p, $n))) + else + return quote + R_field = $R + isdefined(p, $n) ? R_field(zero_rdata(getfield(p, $n))) : R_field() + end end end backing_data_expr = Expr(:call, :tuple, rdata_field_zeros_exprs...) - backing_expr = :($(rdata_backing_type(P))($backing_data_expr)) - return Expr(:call, R, backing_expr) + backing_expr = :(rdata_backing_type($P)($backing_data_expr)) + + return quote + # Get types associated to primal. + T = tangent_type($P) + R = rdata_type(T) + + # If there's no reverse data, return no reverse data, e.g. for mutable types. + R == NoRData && return NoRData() + + # T ought to be a `Tangent`. If it's not, something has gone wrong. + T <: Tangent || error("Unhandled type $T") + + # return $backing_expr + return R($backing_expr) + end end -@generated function zero_rdata(p::Union{Tuple,NamedTuple}) - rdata_type(tangent_type(p)) == NoRData && return NoRData() - return :(tuple_map(zero_rdata, p)) +function zero_rdata(p::P) where {P<:Union{Tuple,NamedTuple}} + return rdata_type(tangent_type(P)) == NoRData ? NoRData() : tuple_map(zero_rdata, p) end +has_definite_fieldcount(P) = P isa DataType && Base.datatype_fieldcount(P) !== nothing + """ can_produce_zero_rdata_from_type(::Type{P}) where {P} @@ -570,15 +600,25 @@ Returns whether or not the zero element of the rdata type for primal type `P` ca obtained from `P` alone. """ @generated function can_produce_zero_rdata_from_type(::Type{P}) where {P} - R = rdata_type(tangent_type(P)) - R == NoRData && return true - isabstracttype(P) && return false - (isconcretetype(P) || P <: Tuple) || return false - (P <: Tuple && !(P isa DataType)) && return false # catch Unions and UnionAlls - (P <: Tuple && Base.datatype_fieldcount(P) === nothing) && return false - - # For general structs, just look at their fields. - return isstructtype(P) ? all(can_produce_zero_rdata_from_type, fieldtypes(P)) : false + if isstructtype(P) && has_definite_fieldcount(P) + can_produces = map(_P -> :(can_produce_zero_rdata_from_type($_P)), fieldtypes(P)) + else + can_produces = () + end + tuple_expr = Expr(:call, :tuple, can_produces...) + + return quote + R = rdata_type(tangent_type($P)) + R == NoRData && return true + $(isabstracttype(P)) && return false + $(isconcretetype(P) || P <: Tuple) || return false + $(P <: Tuple && !(P isa DataType)) && return false # catch Unions and UnionAlls + $(P <: Tuple && !has_definite_fieldcount(P)) && return false + + # For general structs, just look at their fields. + $(!isstructtype(P)) && return false + return stable_all($tuple_expr) + end end can_produce_zero_rdata_from_type(::Type{<:IEEEFloat}) = true