Skip to content

Commit

Permalink
add path option to trainables (#174)
Browse files Browse the repository at this point in the history
* add path=true

* fix

* fix

* fix docs

* fix docs

* update doc workflow
  • Loading branch information
CarloLucibello authored Apr 9, 2024
1 parent a87ffd5 commit 8ca6ce0
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 137 deletions.
24 changes: 0 additions & 24 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,3 @@ jobs:
file: lcov.info
continue-on-error: ${{ matrix.julia-version == 'nightly' }}

docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: '1.6'
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: |
julia --color=yes --project=docs/ -e '
using Optimisers
using Documenter
using Documenter: doctest
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
doctest(Optimisers)'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
7 changes: 7 additions & 0 deletions .github/workflows/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Dependabot configuration: automatically opens PRs to bump the GitHub Actions
# versions used by this repository's workflows.
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
  # Keep the actions referenced in .github/workflows up to date.
  - package-ecosystem: "github-actions"
    directory: "/" # Location of package manifests
    schedule:
      interval: "weekly"
28 changes: 28 additions & 0 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Build and deploy the Documenter.jl docs on pushes to master, on tags, and
# for pull requests (PR builds are previews; deployment happens from master/tags).
name: Documentation

on:
  push:
    branches:
      - master # update to match your development branch (master, main, dev, trunk, ...)
    tags: '*'
  pull_request:

jobs:
  build:
    # Documenter needs to push to gh-pages and set commit statuses.
    permissions:
      contents: write
      statuses: write
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v1
        with:
          version: '1.10'
      - uses: julia-actions/cache@v1
      - name: Install dependencies
        # Dev the package itself into the docs environment so docs build against
        # the checked-out source, not a registered release.
        run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
      - name: Build and deploy
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token
          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key
        run: julia --project=docs/ docs/make.jl
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.4"
Functors = "0.4.9"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
Expand Down
Binary file modified docs/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6 changes: 4 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Documenter, Optimisers, Zygote, StaticArrays, Functors

DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers, Functors); recursive = true)
DocMeta.setdocmeta!(Functors, :DocTestSetup, :(using Functors); recursive = true)

makedocs(modules = [Optimisers],
makedocs(modules = [Optimisers, Functors],
doctest = false,
sitename = "Optimisers.jl",
pages = ["Home" => "index.md",
Expand All @@ -13,6 +14,7 @@ makedocs(modules = [Optimisers],
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"
),
checkdocs = :none, # don't check that Functors' docstrings are all reported here
)

deploydocs(
Expand Down
14 changes: 14 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
```@meta
CollapsedDocStrings = true
```

## Optimisation Rules

Expand Down Expand Up @@ -72,3 +75,14 @@ Optimisers.@lazy
Optimisers.adjust(::AbstractRule, ::Real)
Optimisers.@def
```

## KeyPath

A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure.
It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.

```@docs
Functors.KeyPath
Functors.haskeypath
Functors.getkeypath
```
7 changes: 6 additions & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
module Optimisers

using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
using Functors: functor, fmap, fmap_with_path,
KeyPath, haskeypath, getkeypath,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra

include("interface.jl")
export AbstractRule

include("utils.jl")

include("adjust.jl")

include("destructure.jl")
export destructure

include("trainables.jl")
export trainables
export KeyPath, haskeypath, getkeypath # from Functors.jl

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
Expand Down
2 changes: 1 addition & 1 deletion src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

struct TrainableStructWalk <: AbstractWalk end

(::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)
Expand Down
16 changes: 4 additions & 12 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function _setup(rule, x; cache)
cache[x] =
end
else
valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
end
end

Expand Down Expand Up @@ -82,7 +82,7 @@ function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
x′, re = functor(x)
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
else # no ties to preserve between immutable structs, right?
Expand Down Expand Up @@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
# functor(typeof(tree), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# default all rules to first order calls
Expand Down Expand Up @@ -172,22 +172,14 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr)
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)

function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
@warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
map(c -> c in tr ? c : nothing, ch)
end


valuemap(f, x...) = map(f, x...)
valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
valueforeach(f, x...) = foreach(f, x...)
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
f(v, (get(y, k, nothing) for y in ys)...)
end


###
### rule definition helpers
###
Expand Down
81 changes: 73 additions & 8 deletions src/trainables.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@

"""
trainables(x)
trainables(x; path = false)
Return a list over all the trainable parameters in `x`, that is all the numerical
Return an iterable over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
If `path = false`, the output is a list of numerical arrays.
If `path = true`, the output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref) is a type
representing the path to the array in the original structure.
See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.
# Examples
Expand All @@ -33,27 +38,87 @@ julia> trainables(x)
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
```
```jldoctest
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));
julia> for (kp, y) in trainables(x, path = true)
println(kp, " => ", y)
end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]
julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
3.0
4.0
```
"""
function trainables(x)
function trainables(x; path = false)
    # Dispatch on the keyword: with `path = true` every parameter array is
    # paired with the KeyPath that locates it inside `x`.
    path && return _trainables_with_path(x)
    return _trainables(x)
end


function _trainables(x)
arrays = AbstractArray[]
exclude(x) = Optimisers.isnumeric(x)
fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
push!(arrays, y)
return y
end
return arrays
end

function ∇trainables(x, Δ)
exclude(x) = Optimisers.isnumeric(x)
i = 0
return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return Δ[i+=1]
end
end

function ChainRulesCore.rrule(::typeof(trainables), x)
function ChainRulesCore.rrule(::typeof(_trainables), x)
    # Forward pass: collect the trainable arrays as usual.
    params = trainables(x)
    # Pullback: scatter the flat list of cotangents back onto the structure of `x`.
    function trainables_back(Δ)
        return (NoTangent(), ∇trainables(x, unthunk(Δ)))
    end
    return params, trainables_back
end

function _trainables_with_path(x)
    # Collected as (KeyPath, array) tuples, in traversal order.
    collected = []
    exclude(kp, y) = isnumeric(y)
    fmap_with_path(x; exclude, walk = TrainableStructWalkWithPath()) do kp, y
        push!(collected, (kp, y))
        return y
    end
    return collected
end

# Walk that recurses only into the `trainable` children of each node,
# threading the KeyPath of every child through to the recursion.
struct TrainableStructWalkWithPath <: AbstractWalk end

function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x)
    subtrees = trainable(x)
    # Extend the current path with each child's own key before recursing.
    child_paths = mapkey(key -> KeyPath(kp, key), subtrees)
    return mapvalue(recurse, child_paths, subtrees)
end

function ChainRulesCore.rrule(::typeof(_trainables_with_path), x)
    result = _trainables_with_path(x)
    # Pullback: each output entry is a (path, array) pair; only the array
    # slot of the cotangent flows back into `x`.
    function trainables_with_path_back(Δ)
        return (NoTangent(), ∇trainables_with_path(x, unthunk(Δ)))
    end
    return result, trainables_with_path_back
end

function ∇trainables_with_path(x, Δ)
    # Walk `x` in the same order as the forward pass, consuming one cotangent
    # per trainable leaf. Each Δ entry mirrors a (path, array) pair (or is
    # nothing); only the second slot — the array part — propagates.
    counter = 0
    return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
        Δi = Δ[counter += 1]
        return isnothing(Δi) ? nothing : Δi[2]
    end
end
15 changes: 15 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

# Apply `f` to the values of one or more collections, preserving keys/structure.
# The generic fallback is plain `map` (tuples, NamedTuples, arrays). For a
# `Dict`, the result is a new `Dict` with the same keys; any extra collections
# `ys` are looked up by key, substituting `nothing` for keys they lack.
mapvalue(f, xs...) = map(f, xs...)
mapvalue(f, x::Dict, ys...) =
    Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k, v) in pairs(x))

# Apply `f` to each key (or 1-based index) of `x`, returning a container of
# the same shape whose entries are `f(key)` rather than the original values.
mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
mapkey(f, x::Tuple) = ntuple(f, length(x))
mapkey(f, x::AbstractArray) = [f(i) for i in 1:length(x)]

# Like `foreach`, but for a `Dict` iterates its *values*, looking up matching
# keys in any extra collections `ys` and passing `nothing` where a key is
# absent. Always returns `nothing`.
foreachvalue(f, xs...) = foreach(f, xs...)

function foreachvalue(f, x::Dict, ys...)
    for (k, v) in pairs(x)
        f(v, (get(y, k, nothing) for y in ys)...)
    end
    return nothing
end

Loading

0 comments on commit 8ca6ce0

Please sign in to comment.