
perf: (Enzyme) remove tangent conversion, use native gradient when possible (#730)

* perf: remove tangent conversion with Enzyme, use native gradient when possible

* Drop tests

* bump
gdalle authored Feb 17, 2025
1 parent 3417ff5 commit 1df5621
Showing 8 changed files with 136 additions and 172 deletions.
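For orientation before the diff: this commit makes the Enzyme extension of DifferentiationInterface pass user-supplied tangents straight into `Duplicated`/`BatchDuplicated` instead of converting them to `typeof(x)` first, and routes `gradient`/`value_and_gradient` through Enzyme's native gradient API whenever every extra context is a constant. A minimal usage sketch of that fast path, with an illustrative function `f` that is not part of the commit:

using DifferentiationInterface
import Enzyme  # loads the DifferentiationInterfaceEnzymeExt extension

f(x) = sum(abs2, x)     # illustrative objective
backend = AutoEnzyme()  # default mode and function annotation
x = rand(3)

# With only constant contexts (here: none), these calls now dispatch to
# Enzyme's native gradient / gradient! rather than a hand-built autodiff call.
grad = gradient(f, backend, x)
y, grad = value_and_gradient(f, backend, x)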
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.41"
version = "0.6.42"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -20,8 +20,8 @@ function DI.value_and_pushforward(
) where {F,C}
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
-dx_sametype = convert(typeof(x), only(tx))
-x_and_dx = Duplicated(x, dx_sametype)
+dx = only(tx)
+x_and_dx = Duplicated(x, dx)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
return y, (dy,)
@@ -37,8 +37,8 @@ function DI.value_and_pushforward(
) where {F,B,C}
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
-tx_sametype = map(Fix1(convert, typeof(x)), tx)
-x_and_tx = BatchDuplicated(x, tx_sametype)
+x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
return y, values(ty)
@@ -54,8 +53,8 @@ function DI.pushforward(
) where {F,C}
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
-dx_sametype = convert(typeof(x), only(tx))
-x_and_dx = Duplicated(x, dx_sametype)
+dx = only(tx)
+x_and_dx = Duplicated(x, dx)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))
return (dy,)
@@ -71,8 +70,7 @@ function DI.pushforward(
) where {F,B,C}
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
-tx_sametype = map(Fix1(convert, typeof(x)), tx)
-x_and_tx = BatchDuplicated(x, tx_sametype)
+x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
return values(ty)
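The hunks above cover the one-argument forward operators: the tangents in `tx` are now wrapped directly in `Duplicated`/`BatchDuplicated` rather than converted to `typeof(x)` beforehand; the next file applies the same change to the in-place `f!(y, x)` operators. A hedged usage sketch of the one-argument path, with an illustrative function and seed that are not from the commit:

using DifferentiationInterface
import Enzyme

f(x) = abs2.(x)   # illustrative vector-to-vector function
x = rand(3)
tx = (ones(3),)   # a single tangent, passed as a tuple as DI expects

# The tangent is used as the Duplicated shadow as-is.
y, ty = value_and_pushforward(f, AutoEnzyme(; mode=Enzyme.Forward), x, tx)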
@@ -22,13 +22,13 @@ function DI.value_and_pushforward(
) where {F,C}
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode)
-dx_sametype = convert(typeof(x), only(tx))
-dy_sametype = make_zero(y)
-x_and_dx = Duplicated(x, dx_sametype)
-y_and_dy = Duplicated(y, dy_sametype)
+dx = only(tx)
+dy = make_zero(y)
+x_and_dx = Duplicated(x, dx)
+y_and_dy = Duplicated(y, dy)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)
-return y, (dy_sametype,)
+return y, (dy,)
end

function DI.value_and_pushforward(
@@ -42,13 +42,12 @@ function DI.value_and_pushforward(
) where {F,B,C}
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
-tx_sametype = map(Fix1(convert, typeof(x)), tx)
-ty_sametype = ntuple(_ -> make_zero(y), Val(B))
-x_and_tx = BatchDuplicated(x, tx_sametype)
-y_and_ty = BatchDuplicated(y, ty_sametype)
+ty = ntuple(_ -> make_zero(y), Val(B))
+x_and_tx = BatchDuplicated(x, tx)
+y_and_ty = BatchDuplicated(y, ty)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
-return y, ty_sametype
+return y, ty
end

function DI.pushforward(
@@ -76,13 +75,10 @@ function DI.value_and_pushforward!(
) where {F,B,C}
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
-tx_sametype = map(Fix1(convert, typeof(x)), tx)
-ty_sametype = map(Fix1(convert, typeof(y)), ty)
-x_and_tx = BatchDuplicated(x, tx_sametype)
-y_and_ty = BatchDuplicated(y, ty_sametype)
+x_and_tx = BatchDuplicated(x, tx)
+y_and_ty = BatchDuplicated(y, ty)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
-foreach(copyto_if_different_addresses!, ty, ty_sametype)
return y, ty
end

@@ -8,8 +8,7 @@ function seeded_autodiff_thunk(
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
-dresult_righttype = convert(typeof(result), dresult)
-dinputs = only(reverse(f, args..., dresult_righttype, tape))
+dinputs = only(reverse(f, args..., dresult, tape))
else
shadow_result .+= dresult # TODO: generalize beyond arrays
dinputs = only(reverse(f, args..., tape))
@@ -32,8 +31,7 @@ function batch_seeded_autodiff_thunk(
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
-dresults_righttype = map(Fix1(convert, typeof(result)), dresults)
-dinputs = only(reverse(f, args..., dresults_righttype, tape))
+dinputs = only(reverse(f, args..., dresults, tape))
else
foreach(shadow_results, dresults) do d0, d
d0 .+= d # use recursive_add here?
@@ -141,13 +139,12 @@ function DI.value_and_pullback!(
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode))
RA = guess_activity(typeof(prep.y_example), mode)
-dx_righttype = convert(typeof(x), only(tx))
-make_zero!(dx_righttype)
+dx = only(tx)
+make_zero!(dx)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
_, result = seeded_autodiff_thunk(
-mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts...
+mode, only(ty), f_and_df, RA, Duplicated(x, dx), annotated_contexts...
)
-copyto_if_different_addresses!(only(tx), dx_righttype)
return result, tx
end

@@ -163,13 +160,11 @@ function DI.value_and_pullback!(
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))
RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B))
-tx_righttype = map(Fix1(convert, typeof(x)), tx)
-make_zero!(tx_righttype)
+make_zero!(tx)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
_, result = batch_seeded_autodiff_thunk(
-mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts...
+mode, ty, f_and_df, RA, BatchDuplicated(x, tx), annotated_contexts...
)
-foreach(copyto_if_different_addresses!, tx, tx_righttype)
return result, tx
end
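To illustrate the in-place pullback methods above: the caller's tangent buffer `tx` is now zeroed with `make_zero!` and written to directly, instead of being copied through a converted buffer of `typeof(x)`. A hedged sketch of how this operator is reached from the user side, with illustrative names and a scalar-valued `f`:

using DifferentiationInterface
import Enzyme

f(x) = sum(abs2, x)
backend = AutoEnzyme()
x = rand(3)
ty = (1.0,)          # one seed for the scalar output
tx = (similar(x),)   # cotangent buffer, overwritten in place

prep = prepare_pullback(f, backend, x, ty)
y, tx = value_and_pullback!(f, tx, prep, backend, x, ty)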

@@ -187,10 +182,73 @@ end

## Gradient

-### Without preparation
+function DI.prepare_gradient(
+f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}
+) where {F,C}
+return DI.NoGradientPrep()
+end
+
+### Enzyme gradient API (only constants)
+
+function DI.gradient(
+f::F,
+::DI.NoGradientPrep,
+backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+x,
+contexts::Vararg{DI.Constant,C},
+) where {F,C}
+mode = reverse_noprimal(backend)
+f_and_df = get_f_and_df(f, backend, mode)
+annotated_contexts = translate(backend, mode, Val(1), contexts...)
+grads = gradient(mode, f_and_df, x, annotated_contexts...)
+return first(grads)
+end
+
+function DI.value_and_gradient(
+f::F,
+::DI.NoGradientPrep,
+backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+x,
+contexts::Vararg{DI.Constant,C},
+) where {F,C}
+mode = reverse_withprimal(backend)
+f_and_df = get_f_and_df(f, backend, mode)
+annotated_contexts = translate(backend, mode, Val(1), contexts...)
+grads, result = gradient(mode, f_and_df, x, annotated_contexts...)
+return result, first(grads)
+end
+
+function DI.gradient!(
+f::F,
+grad,
+::DI.NoGradientPrep,
+backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+x,
+) where {F}
+mode = reverse_noprimal(backend)
+f_and_df = get_f_and_df(f, backend, mode)
+gradient!(mode, grad, f_and_df, x)
+return grad
+end
+
+function DI.value_and_gradient!(
+f::F,
+grad,
+::DI.NoGradientPrep,
+backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
+x,
+) where {F}
+mode = reverse_withprimal(backend)
+f_and_df = get_f_and_df(f, backend, mode)
+_, result = gradient!(mode, grad, f_and_df, x)
+return result, grad
+end
+
+### Generic

function DI.gradient(
f::F,
+::DI.NoGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{DI.Context,C},
@@ -213,6 +271,7 @@ end

function DI.value_and_gradient(
f::F,
+::DI.NoGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{DI.Context,C},
@@ -233,73 +292,34 @@ function DI.value_and_gradient(
end
end

-### With preparation
-
-struct EnzymeGradientPrep{G} <: DI.GradientPrep
-grad_righttype::G
-end
-
-function DI.prepare_gradient(
-f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}
-) where {F,C}
-grad_righttype = make_zero(x)
-return EnzymeGradientPrep(grad_righttype)
-end
-
-function DI.gradient(
-f::F,
-::EnzymeGradientPrep,
-backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
-x,
-contexts::Vararg{DI.Context,C},
-) where {F,C}
-return DI.gradient(f, backend, x, contexts...)
-end

function DI.gradient!(
f::F,
grad,
-prep::EnzymeGradientPrep,
+::DI.NoGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
mode = reverse_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
-grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
-make_zero!(grad_righttype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
-autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...)
-copyto_if_different_addresses!(grad, grad_righttype)
+make_zero!(grad)
+autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...)
return grad
end

-function DI.value_and_gradient(
-f::F,
-::EnzymeGradientPrep,
-backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
-x,
-contexts::Vararg{DI.Context,C},
-) where {F,C}
-return DI.value_and_gradient(f, backend, x, contexts...)
-end

function DI.value_and_gradient!(
f::F,
grad,
-prep::EnzymeGradientPrep,
+::DI.NoGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
mode = reverse_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
-grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
-make_zero!(grad_righttype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
-_, y = autodiff(
-mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...
-)
-copyto_if_different_addresses!(grad, grad_righttype)
+make_zero!(grad)
+_, y = autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...)
return y, grad
end
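For reference, a hedged sketch (not part of the diff) of the native Enzyme calls that the constants-only gradient methods above map onto; it assumes Enzyme's public `gradient`/`gradient!` API and omits the `f_and_df` wrapping and context translation done by the extension. It also shows why the code takes `first(grads)` and destructures `grads, result`: Enzyme returns one derivative per argument, plus the primal value in the WithPrimal modes.

import Enzyme

f(x) = sum(abs2, x)
x = rand(3)

grads = Enzyme.gradient(Enzyme.Reverse, f, x)  # tuple with one derivative per argument
dx = first(grads)

grads, result = Enzyme.gradient(Enzyme.ReverseWithPrimal, f, x)  # derivatives plus primal

dx_buffer = similar(x)
Enzyme.gradient!(Enzyme.Reverse, dx_buffer, f, x)  # in-place variant used by DI.gradient!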

2 comments on commit 1df5621

@gdalle (Member, Author) commented on 1df5621, Feb 17, 2025

@JuliaRegistrator register subdir=DifferentiationInterface

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/125257

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a DifferentiationInterface-v0.6.42 -m "<description of version>" 1df562180bdcc3e91c885aa5f4162a0be2ced850
git push origin DifferentiationInterface-v0.6.42
