Skip to content

Commit

Permalink
remove leaf.frozen field
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 28, 2022
1 parent e17e474 commit f046185
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/adjust.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
adjust(::Nothing, ::Real) = nothing
adjust(::Nothing; kw...) = nothing

adjust(ℓ::Leaf, eta::Real) = .frozen ?: Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen)
adjust(ℓ::Leaf; kw...) = .frozen ?: Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen)
adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)


"""
Expand Down
6 changes: 2 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ abstract type AbstractRule end
mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
rule::R
state::S
frozen::Bool # mutability also allows this flag to be changed
end

@functor Leaf
Expand All @@ -25,7 +24,7 @@ function setup(rule::AbstractRule, model)
# Rely on Functors to identify shared arrays, they will share a Leaf in this tree:
tree = fmapstructure(model, exclude = isnumeric) do x
cnt[] += 1
Leaf(rule, init(rule, x), false)
Leaf(rule, init(rule, x))
end
cnt[] == 0 && @warn "setup found no parameters in the given model"
tree
Expand All @@ -35,7 +34,7 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(ioc, ", ", ℓ.frozen, ")")
print(ioc, ")")
end

###
Expand All @@ -49,7 +48,6 @@ function update!(tree, model, grad)
grads!(dict, tree, model, grad)
# Second walk is to update the model. The walk taken follows Leaf identity
newmodel = fmap(tree, model; exclude =->isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x
.frozen && return x
haskey(dict, ℓ) || return x # no gradient seen, nothing to do
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ])
.state = s′ # to get state out of here, rely on mutability of Leaf
Expand Down

0 comments on commit f046185

Please sign in to comment.