Skip to content

Commit

Permalink
Merge branch 'main' into rwclerr
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 7, 2024
2 parents 2ff7a18 + c0c07c3 commit e1fc368
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 11 deletions.
36 changes: 25 additions & 11 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
[`Active`](@ref) will automatically convert plain integers to floating
point values, but cannot do so for integer values in tuples and structs.
"""
@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten}
@inline function autodiff(rmode::ReverseMode{ReturnPrimal, RABI,Holomorphic, ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI,Holomorphic, Nargs, ErrIfFuncWritten}
tt′ = vaTypeof(args...)
width = same_or_one(1, args...)
if width == 0
Expand All @@ -239,17 +239,23 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
ModifiedBetween = Val(falses_from_args(Nargs+1))

tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...}

rt = if A isa UnionAll
Core.Compiler.return_type(f.val, tt)
else
eltype(A)
end

FTy = Core.Typeof(f.val)

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt′)
else
Val(codegen_world_age(Core.Typeof(f.val), tt))
Val(codegen_world_age(FTy, tt))
end

rt = if A isa UnionAll
@static if VERSION >= v"1.8.0"
Compiler.primal_return_type(rmode, Val(codegen_world_age(FTy, tt)), FTy, tt)
else
Core.Compiler.return_type(f.val, tt)
end
else
eltype(A)
end

if A <: Active
Expand Down Expand Up @@ -333,7 +339,11 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value.
"""
@inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs}
tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...}
rt = Core.Compiler.return_type(f.val, tt)
rt = if mode isa ReverseMode && VERSION >= v"1.8.0"
Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt)
else
Core.Compiler.return_type(f.val, tt)
end
A = guess_activity(rt, mode)
autodiff(mode, f, A, args...)
end
Expand Down Expand Up @@ -546,8 +556,12 @@ Like [`autodiff_deferred`](@ref) but will try to guess the activity of the retur

@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs}
tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...}
world = codegen_world_age(Core.Typeof(f.val), tt)
rt = Core.Compiler.return_type(f.val, tt)
rt = if mode isa ReverseMode && VERSION >= v"1.8.0"
Compiler.primal_return_type(mode, Val(codegen_world_age(eltype(FA), tt)), eltype(FA), tt)
else
Core.Compiler.return_type(f.val, tt)
end

if rt === Union{}
error("return type is Union{}, giving up.")
end
Expand Down
20 changes: 20 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3436,6 +3436,26 @@ Create the methodinstance pair, and lookup the primal return type.
return primal
end

@generated function primal_return_type(::ReverseMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT}
mode = Enzyme.API.DEM_ReverseModeCombined
interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_REV_CACHE, nothing, world, mode)
res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...})
return quote
Base.@_inline_meta
$res
end
end

@generated function primal_return_type(::ForwardMode, ::Val{world}, ::Type{FT}, ::Type{TT}) where {world, FT, TT}
mode = Enzyme.API.DEM_ForwardMode
interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(Enzyme.Compiler.GLOBAL_FWD_CACHE, nothing, world, mode)
res = Core.Compiler._return_type(interp, Tuple{FT, TT.parameters...})
return quote
Base.@_inline_meta
$res
end
end

##
# Enzyme compiler step
##
Expand Down

0 comments on commit e1fc368

Please sign in to comment.