Skip to content

Commit

Permalink
Add ABI setter (#1709)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Aug 7, 2024
1 parent ece399d commit d19307d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 4 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true, false}()
@inline set_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,true}()
@inline clear_err_if_func_written(::ReverseMode{ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,ABI,Holomorphic,false}()

@inline set_abi(::ReverseMode{ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten}, ::Type{NewABI}) where {ReturnPrimal,OldABI,Holomorphic,ErrIfFuncWritten,NewABI<:ABI} = ReverseMode{ReturnPrimal,NewABI,Holomorphic,ErrIfFuncWritten}()

"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI}
Expand Down Expand Up @@ -255,6 +257,8 @@ const Forward = ForwardMode{DefaultABI, false}()
@inline set_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,true}()
@inline clear_err_if_func_written(::ForwardMode{ABI,ErrIfFuncWritten}) where {ABI,ErrIfFuncWritten} = ForwardMode{ABI,false}()

@inline set_abi(::ForwardMode{OldABI,ErrIfFuncWritten}, ::Type{NewABI}) where {OldABI,ErrIfFuncWritten,NewABI<:ABI} = ForwardMode{NewABI,ErrIfFuncWritten}()

function autodiff end
function autodiff_deferred end
function autodiff_thunk end
Expand Down
4 changes: 2 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import EnzymeCore
import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal

import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written
export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written
import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi
export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI, NonGenABI, set_err_if_func_written, clear_err_if_func_written, set_abi

import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc
Expand Down
12 changes: 6 additions & 6 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ using Test
res = autodiff(Reverse, f, Const, Const(nothing))
@test res === ((nothing,),)

res = autodiff(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing))
res = autodiff(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing))
@test res === ((nothing,),)

@test () === autodiff(Forward, f, Const, Const(nothing))
@test () === autodiff(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing))
@test () === autodiff(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing))

res = autodiff(Reverse, f, Const(nothing))
@test res === ((nothing,),)
Expand All @@ -22,11 +22,11 @@ using Test

res = autodiff_deferred(Reverse, f, Const(nothing))
@test res === ((nothing,),)
res = autodiff_deferred(ReverseMode{false,NonGenABI, false, false}(), f, Const, Const(nothing))
res = autodiff_deferred(Enzyme.set_abi(Reverse, NonGenABI), f, Const, Const(nothing))
@test res === ((nothing,),)

@test () === autodiff_deferred(Forward, f, Const(nothing))
@test () === autodiff_deferred(ForwardMode{NonGenABI, false}(), f, Const, Const(nothing))
@test () === autodiff_deferred(Enzyme.set_abi(Forward, NonGenABI), f, Const, Const(nothing))

# ConstType -> Type{Int}
res = autodiff(Reverse, f, Const, Const(Int))
Expand Down Expand Up @@ -65,15 +65,15 @@ using Test
_, res0 = autodiff(Reverse, unused, Active, Const(nothing), Active(2.0))[1]
@test res0 1.0

_, res0 = autodiff(ReverseMode{false, NonGenABI, false, false}(), unused, Active, Const(nothing), Active(2.0))[1]
_, res0 = autodiff(Enzyme.set_abi(Reverse, NonGenABI), unused, Active, Const(nothing), Active(2.0))[1]
@test res0 1.0

res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0))
@test res0 1.0
res0, = autodiff(Forward, unused, DuplicatedNoNeed, Const(nothing), DuplicatedNoNeed(2.0, 1.0))
@test res0 1.0

res0, = autodiff(ForwardMode{NonGenABI, false}(), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0))
res0, = autodiff(Enzyme.set_abi(Forward, NonGenABI), unused, DuplicatedNoNeed, Const(nothing), Duplicated(2.0, 1.0))
@test res0 1.0

_, res0 = autodiff(Reverse, unused, Const(nothing), Active(2.0))[1]
Expand Down

0 comments on commit d19307d

Please sign in to comment.