Don't load Yota at all (#166)
* Update runtests.jl

* Update runtests.jl

* Update Project.toml

* Update runtests.jl

* Update rules.jl

* Update destructure.jl

* test is no longer broken

* Update index.md
mcabbott authored Feb 6, 2024
1 parent 88b527c commit 6473c45
Showing 5 changed files with 15 additions and 33 deletions.
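
The change follows one pattern across the test files, visible in the diffs below: Yota is dropped from the test dependencies, its import and the `Yota_gradient`/`y2z` helpers are commented out, and each `@testset "using Yota"` is kept but guarded by a short-circuit `false &&`. A minimal sketch of that guard pattern (the inner `@test` line is illustrative only, not taken from the repository):

```julia
using Test

# `false &&` short-circuits: the testset still has to parse and macro-expand,
# but it is never run, so names like `Yota_gradient` that are no longer
# defined are never looked up.
false && @testset "using Yota" begin
    @test Yota_gradient(identity, 1.0) == (1.0,)   # never executed
end
```
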
4 changes: 1 addition & 3 deletions Project.toml
@@ -14,15 +14,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ChainRulesCore = "1"
 Functors = "0.4"
 Statistics = "1"
-Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Yota", "Zygote"]
+test = ["Test", "StaticArrays", "Zygote"]
19 changes: 0 additions & 19 deletions docs/src/index.md
@@ -80,25 +80,6 @@ Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
-
-## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
-
-Yota is another modern automatic differentiation package, an alternative to Zygote.
-
-Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
-but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write (for the Flux model above):
-
-```julia
-using Yota
-
-loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-    sum(m(x))
-end;
-
-# Or else, this may save computing ∇image:
-loss, (_, ∇model) = grad(m -> sum(m(image)), model);
-```
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
 The main design difference of Lux from Flux is that the tree of parameters is separate from
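For context on the snippet removed above: the `∇model` it extracts is used with Optimisers.jl exactly like a Zygote gradient. A minimal, self-contained sketch (the NamedTuple model and hand-written gradient below are illustrative stand-ins, not from the docs):

```julia
using Optimisers

# Stand-in model and gradient with matching structure; in the removed docs
# example ∇model would come from Yota.grad (or Zygote.gradient).
model  = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0])
∇model = (weight = [0.1 0.1; 0.1 0.1], bias = [0.1, 0.1])

state = Optimisers.setup(Adam(0.001), model)            # optimiser state mirrors the model's structure
state, model = Optimisers.update(state, model, ∇model)  # one optimisation step
```
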
6 changes: 3 additions & 3 deletions test/destructure.jl
@@ -98,14 +98,14 @@ end
     sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
   end[1] == [378, 378, 378]
 
-  @test_broken gradient([1,2,3.0]) do v
+  VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v
     sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
   end[1] ≈ [8,16,24]
   # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
   # Diffractor error in perform_optic_transform
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
   @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
   @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@@ -175,7 +175,7 @@
   # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   re1 = destructure(m1)[2]
   @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
   re2 = destructure(m2)[2]
4 changes: 2 additions & 2 deletions test/rules.jl
@@ -230,7 +230,7 @@ end
   end
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   @testset "$(name(o))" for o in RULES
     w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
     w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
@@ -266,4 +266,4 @@

   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3
-end
+end
15 changes: 9 additions & 6 deletions test/runtests.jl
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
+using ChainRulesCore, Functors, StaticArrays, Zygote
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
 using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -38,12 +38,15 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
   return state, dx
 end
 
-# Make Yota's output look like Zygote's:
+# if VERSION < v"1.9-"
+#   using Yota
+# end
+# # Make Yota's output look like Zygote's:
 
-Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
-y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
-y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
-y2z(x) = x
+# Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
+# y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
+# y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
+# y2z(x) = x
 
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin
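The `y2z` helpers that this commit comments out convert Yota's ChainRulesCore-style output (Tangents and zero tangents) into the plain nested NamedTuples that the tests compare against. A self-contained sketch of that conversion using only ChainRulesCore, with a hypothetical `Layer` struct since Yota itself is no longer loaded:

```julia
using ChainRulesCore

# Same idea as the commented-out helpers in runtests.jl:
y2z(::AbstractZero) = nothing                                        # collapse all flavours of zero
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))  # Tangent -> NamedTuple
y2z(x) = x                                                           # arrays etc. pass through

struct Layer; W; b; end   # hypothetical struct, for illustration only

t = Tangent{Layer}(W = [1.0 2.0], b = NoTangent())
y2z(t)   # (W = [1.0 2.0], b = nothing), i.e. the shape Zygote would return
```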
