Use the cache less often #39

Merged
merged 7 commits into from
Oct 20, 2022
Changes from 3 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Functors"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
authors = ["Mike J Innes <[email protected]>"]
-version = "0.3.0"
+version = "0.4.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11 changes: 6 additions & 5 deletions src/Functors.jl
@@ -5,7 +5,6 @@ export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect
include("functor.jl")
include("base.jl")


###
### Docstrings for basic functionality
###
@@ -132,19 +131,21 @@ Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)

-julia> twice = [1, 2];
+julia> twice = [1, 2];  # println only acts once on this

julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)
```

-If the same node (same according to `===`) appears more than once,
-it will only be handled once, and only be transformed once with `f`.
-Thus the result will also have this relationship.
+Mutable objects which appear more than once are only handled once (by caching `f(x)` in an `IdDict`).
+Thus the relationship `x.i === x.iv[1]` will be preserved.
+An immutable object which appears twice is not stored in the cache, thus `f(34)` will be called twice,
+and the results will agree only if `f` is pure.

By default, `Tuple`s, `NamedTuple`s, and some other container-like types in Base have
children to recurse into. Arrays of numbers do not.
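
A minimal sketch of the rule described in the new docstring text, written against Base only rather than Functors itself: results are stored in an `IdDict` only for mutable values, so a shared array is transformed once, while a repeated immutable number is transformed every time it appears. The names `toy_map`, `calls`, and `counted_float` are hypothetical and exist only for this illustration; the real `fmap` also handles `exclude`, `walk`, and `prune`.

```julia
# Toy illustration (not Functors' implementation): map `f` over the leaves of
# nested tuples/NamedTuples, caching results only for mutable leaves.
function toy_map(f, x, cache = IdDict())
    if x isa Union{Tuple, NamedTuple}
        # These containers are immutable, so recurse without caching them.
        return map(y -> toy_map(f, y, cache), x)
    end
    if ismutable(x)
        # Same object (by ===) seen before: reuse the stored result.
        haskey(cache, x) && return cache[x]
        return cache[x] = f(x)
    end
    return f(x)  # immutable leaves are never stored in the cache
end

calls = Ref(0)
counted_float(x) = (calls[] += 1; float(x))

twice = [1, 2]
out = toy_map(counted_float, (i = twice, ii = 34, iv = (twice, 34)))

out.i === out.iv[1]   # the shared array stays shared in the result
calls[]               # 3: once for `twice`, and once for each occurrence of 34
```

Because `34` is looked up by value and never cached, an impure `f` could return different results for its two occurrences, which is the caveat the docstring spells out.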
64 changes: 57 additions & 7 deletions src/functor.jl
@@ -39,12 +39,39 @@ function _default_walk(f, x)
re(map(f, func))
end

+usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
+usecache(::Nothing, x) = false
+
+# function _anymutable(x::T) where {T}
+#   ismutable(x) && return true
+#   fs = fieldnames(T)
+#   isempty(fs) && return false
+#   return any(f -> anymutable(getfield(x, f)), fs)
+# end
Member

Does this fail to constant fold sometimes?

Otherwise LGTM. About the weird cases, we could argue it's more conservative to not cache in both. A false positive seems much worse than a false negative here IMO. Asking users of a higher-level isleaf to take on additional responsibility for caching is also fine. Incidentally, this is why I think extracting caching out of Functors and making callbacks handle memoization themselves would be nice.
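
One way to picture that last suggestion, as a sketch only: the callback remembers mutable arguments by identity, so the traversal itself would need no cache. `memoize_by_identity` is a hypothetical helper written for this illustration and is not part of Functors.

```julia
# Hypothetical helper: wrap `f` so that a repeated mutable argument (same by ===)
# returns the stored result, while immutable arguments always call `f` again.
function memoize_by_identity(f)
    seen = IdDict{Any,Any}()
    return function (x)
        ismutable(x) || return f(x)   # never cache immutables: no false positives
        return get!(() -> f(x), seen, x)
    end
end

shared = [1, 2, 3]
g = memoize_by_identity(float)

a = g(shared)
b = g(shared)      # same input object, so the stored result is returned
c = g([1, 2, 3])   # equal contents but a different object: computed afresh

a === b   # true
a === c   # false
```

Declining to cache immutables matches the conservative choice argued for in the comment: a missed cache hit only costs a repeated call, whereas a wrong hit would tie together values that were never shared.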

Member Author

It fails to be instant on surprisingly simple functions; I didn't try to dig into why:

julia> @btime fun_anymutable((x=(1,2), y=3))
  min 36.500 ns, mean 38.715 ns (1 allocation, 32 bytes)
false

julia> @btime gen_anymutable((x=(1,2), y=3))
  min 0.001 ns, mean 0.014 ns (0 allocations)
false

Perhaps more surprisingly, the generated one is also not free e.g. here:

julia> @btime fun_anymutable($(Metalhead.ResNet()))
  min 275.685 ns, mean 323.217 ns (9 allocations, 320 bytes)
true

julia> @btime gen_anymutable($(Metalhead.ResNet()))
  min 147.536 ns, mean 161.010 ns (1 allocation, 32 bytes)
true

That contains Chain([...]) which... should just stop the recursion?

Member Author

Smaller example, in which the number of layers seems to matter:

julia> model = Chain(
         Conv((3, 3), 1 => 16),                # 160 parameters
         Conv((3, 3), 16 => 16),               # 2_320 parameters
         Conv((3, 3), 16 => 32),               # 4_640 parameters
         Conv((3, 3), 32 => 32),               # 9_248 parameters
         Conv((3, 3), 32 => 64),               # 18_496 parameters
         Conv((3, 3), 64 => 64),               # 36_928 parameters
         Dense(16384 => 10),                   # 163_850 parameters
       );

julia> @btime fun_anymutable($model)
  min 327.851 ns, mean 448.404 ns (10 allocations, 3.17 KiB)
true

julia> @btime gen_anymutable($model)
  min 215.463 ns, mean 238.700 ns (8 allocations, 608 bytes)
true

julia> model = Chain(
         Conv((3, 3), 1 => 16),                # 160 parameters
         Conv((3, 3), 16 => 16),               # 2_320 parameters
         # Conv((3, 3), 16 => 32),               # 4_640 parameters
         # Conv((3, 3), 32 => 32),               # 9_248 parameters
         # Conv((3, 3), 32 => 64),               # 18_496 parameters
         Conv((3, 3), 64 => 64),               # 36_928 parameters
         Dense(16384 => 10),                   # 163_850 parameters
       );

julia> @btime fun_anymutable($model)
  min 344.818 ns, mean 391.967 ns (10 allocations, 1.75 KiB)
true

julia> @btime gen_anymutable($model)
  min 0.001 ns, mean 0.014 ns (0 allocations)
true

Member

For the Metalhead example at least, that one allocation is coming from https://github.com/FluxML/Metalhead.jl/blob/7827ca6ec4ef7c5e07d04cd6d84a1a3b11289dc0/src/convnets/resnets/resnet.jl#L17. For the longer Chain, Cthulhu tells me

┌ Info: Inference didn't cache this call information because of imprecise analysis due to recursion:
└ Cthulhu nevertheless is trying to descend into it for further inspection.

If I add a guard against a possible missing (as any can return) in gen_anymutable, and assert the return type, like so:

@generated function gen_anymutable(x::T) where {T}
  ismutabletype(T) && return true
  fs = fieldnames(T)
  isempty(fs) && return false
  subs =  [:(gen_anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
  return :(coalesce(|($(subs...)), false)::Bool)
end

That eliminates all but 6 of the allocations. I believe these correspond to the 6 Conv layers because the check on the Dense layer appears to be fully const folded (why only the Dense? Not sure).
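
For context on that guard: Julia's `|` follows three-valued logic, so a `missing` operand can make the whole expression `missing` instead of a `Bool`. The Base-only lines below illustrate the behaviour that `coalesce(..., false)` and the `::Bool` assertion rule out. This is a sketch, not code from the PR, and `checked` is a hypothetical name.

```julia
true  | missing                   # true: one `true` already decides the result
false | missing                   # missing: the outcome is unknown
coalesce(false | missing, false)  # false: `missing` is replaced by the default

# The type assertion turns an unexpected `missing` into an error rather than
# letting a non-Bool value flow onwards:
checked(x) = (false | x)::Bool
checked(true)                     # true
# checked(missing) would throw a TypeError
```

Either device tells inference that the result is a plain `Bool`, which is presumably why the assertion alone is reported as sufficient in the reply that follows.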

Member Author

Oh nice. Just the ::Bool seems to be enough, and should be safe I think.

Weirdly it is instant for 5 and 7 conv layers, only exactly 6 causes it to fail & take 100ns.

Member

That is absolutely bizarre. It also works for 6 Conv layers if I remove the final Dense and up to at least 32 with/without. Granted, this is on nightly—I couldn't get close to your timings on 1.8.2 IIRC.

Member

Ok, less bizarre. Trying to simplify things a bit, it looks like the first call is always taking a hit here, but subsequent calls are fine.

using BenchmarkTools

struct Conv{N,M,F,A,V}
  σ::F
  weight::A
  bias::V
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
  groups::Int
end

struct Dense{F, M, B}
  weight::M
  bias::B
  σ::F
end

@generated function anymutable(x::T) where {T}
  ismutabletype(T) && return true
  fs = fieldnames(T)
  isempty(fs) && return false
  subs =  [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
  return :(|($(subs...))::Bool)
end

function test()
  for N in (5, 6, 7)
    @info N
    layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
    layers = (layers..., Dense(ones(1), ones(1), identity))
    @btime anymutable($layers)
  end
end

test()

Perhaps that has something to do with the generated function?

Member Author (@mcabbott, Oct 9, 2022)

It's weird.

I don't suggest doing this, but this does seem to compile away:

julia> Base.@assume_effects :total function fun_anymutable3(x::T) where {T}
          ismutable(x) && return true
          fs = fieldnames(T)
          isempty(fs) && return false
          return any(f -> fun_anymutable3(getfield(x, f)), fs)::Bool
       end
fun_anymutable3 (generic function with 1 method)

julia> function test_3()
         for N in (5, 6, 7)
           @info N
           layers = ntuple(_ -> Conv(identity, ones(1), ones(1), (1,), (1,), (1,), 1), N)
           layers = (layers..., Dense(ones(1), ones(1), identity))
           @btime fun_anymutable3($layers)
         end
       end
test_3 (generic function with 1 method)

julia> test_3()
[ Info: 5
  min 0.083 ns, mean 0.185 ns (0 allocations)
[ Info: 6
  min 0.083 ns, mean 0.208 ns (0 allocations)
[ Info: 7
  min 0.083 ns, mean 0.229 ns (0 allocations)

julia> VERSION
v"1.9.0-DEV.1528"

(Edit -- inserted results)

Member (@ToucheSir, Oct 9, 2022)

Hmm, that still allocates (more, as a matter of fact) for me on nightly and 1.8. At least performance is consistent though.

[ Info: 5
  176.931 ns (7 allocations: 1.05 KiB)
[ Info: 6
  177.480 ns (7 allocations: 1.23 KiB)
[ Info: 7
  188.679 ns (7 allocations: 1.34 KiB)

julia> VERSION
v"1.9.0-DEV.1547"

+@generated function anymutable(x::T) where {T}
+  ismutabletype(T) && return true
+  fs = fieldnames(T)
+  isempty(fs) && return false
+  subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fs)]
+  return :(|($(subs...)))
+end

struct NoKeyword end

-function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
-  haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
-  cache[x] = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x)
-end
+function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
+  if usecache(cache, x) && haskey(cache, x)
+    return prune isa NoKeyword ? cache[x] : prune
+  end
+  ret = if exclude(x)
+    f(x)
+  else
+    walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
+  end
+  if usecache(cache, x)
+    cache[x] = ret
+  end
+  ret
+end

###
### Extras
@@ -69,9 +96,19 @@ end
### Vararg forms
###

-function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())
-  haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune
-  cache[x] = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...)
+function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
+  if usecache(cache, x) && haskey(cache, x)
+    return prune isa NoKeyword ? cache[x] : prune
+  end
+  ret = if exclude(x)
+    f(x, ys...)
+  else
+    walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
+  end
+  if usecache(cache, x)
+    cache[x] = ret
+  end
+  ret
 end

function _default_walk(f, x, ys...)
@@ -108,3 +145,16 @@ end
 macro flexiblefunctor(args...)
   flexiblefunctorm(args...)
 end
+
+###
+### Compat
+###
+
+if VERSION < v"1.7"
+  # Copied verbatim from Base, except omitting the macro:
+  function ismutabletype(@nospecialize t)
+    # @_total_meta
+    t = unwrap_unionall(t)
+    return isa(t, DataType) && t.name.flags & 0x2 == 0x2
+  end
+end
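
To illustrate why the generated `anymutable` relies on `ismutabletype`: inside a `@generated` body only the type `T` is available, not the value, so the value-level `ismutable(x)` cannot be used there. A short sketch, assuming Julia 1.7 or later, where Base provides `ismutabletype`; the structs `Frozen` and `Thawed` are hypothetical examples.

```julia
# Value-level query: needs an instance.
ismutable([1, 2])              # true
ismutable((1, 2))              # false

# Type-level query: works from the type alone (Base, Julia >= 1.7).
ismutabletype(Vector{Int})     # true
ismutabletype(Tuple{Int,Int})  # false

struct Frozen; x::Int; end
mutable struct Thawed; x::Int; end
ismutabletype(Frozen), ismutabletype(Thawed)   # (false, true)
```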
49 changes: 44 additions & 5 deletions test/basics.jl
@@ -1,5 +1,5 @@

-using Functors: functor
+using Functors: functor, usecache

struct Foo; x; y; end
@functor Foo
@@ -14,6 +14,7 @@ struct NoChildren2; x; y; end

struct NoChild{T}; x::T; end


###
### Basic functionality
###
@@ -47,28 +48,66 @@ end
@test (model′.x, model′.y, model′.z) == (1, 4, 3)
end

-@testset "cache" begin
+@testset "Sharing" begin
shared = [1,2,3]
m1 = Foo(shared, Foo([1,2,3], Foo(shared, [1,2,3])))
m1f = fmap(float, m1)
@test m1f.x === m1f.y.y.x
@test m1f.x !== m1f.y.x
m1p = fmapstructure(identity, m1; prune = nothing)
@test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3])))
m1no = fmap(float, m1; cache = nothing) # disable the cache by hand
@test m1no.x !== m1no.y.y.x

# A non-leaf node can also be repeated:
# Here "4" is not shared, because Foo isn't leaf:
m2 = Foo(Foo(shared, 4), Foo(shared, 4))
@test m2.x === m2.y
m2f = fmap(float, m2)
@test m2f.x.x === m2f.y.x
m2p = fmapstructure(identity, m2; prune = Bar(0))
-@test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0))
+@test m2p == (x = (x = [1, 2, 3], y = 4), y = (x = Bar{Int64}(0), y = 4))

# Repeated isbits types should not automatically be regarded as shared:
m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared))
m3p = fmapstructure(identity, m3; prune = 0)
@test m3p.y.y == 0
-@test_broken m3p.y.x == 1:3
+@test m3p.y.x == 1:3

# All-isbits trees need not create a cache at all:
m4 = (x=1, y=(2, 3), z=4:5)
@test isbits(fmap(float, m4))
@test_skip 0 == @allocated fmap(float, m4) # true, but fails in tests

# Shared mutable containers are preserved, even if all children are isbits:
ref = Ref(1)
m5 = (x = ref, y = ref, z = Ref(1))
m5f = fmap(x -> x/2, m5)
@test m5f.x === m5f.y
@test m5f.x !== m5f.z

@testset "usecache" begin
d = IdDict()

# Leaf types:
@test usecache(d, [1,2])
@test !usecache(d, 4.0)
@test usecache(d, NoChild([1,2]))
@test !usecache(d, NoChild((3,4)))

# Not leaf:
@test usecache(d, Ref(3)) # mutable container
@test !usecache(d, (5, 6.0))
@test !usecache(d, (a = 2pi, b = missing))

@test !usecache(d, (5, [6.0]')) # contains mutable
@test !usecache(d, (x = [1,2,3], y = 4))

usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care?

# No dictionary:
@test !usecache(nothing, [1,2])
@test !usecache(nothing, 3)
end
end

@testset "functor(typeof(x), y) from @functor" begin