Skip to content

Commit

Permalink
Updates for 1.11 (#288)
Browse files Browse the repository at this point in the history
* Begin refactor for 1-11

* Tidy up a bit

* Tidy up a bit further

* Some work

* StandardFDataType includes NoFData

* Basic tests passing for Memory

* Some work

* Initial memory implementation + array tests

* Check number of arguments returned by rrule

* More imports

* Alterations to abstract interpreter for 1.11

* Ignore versioned manifests

* Qualify import

* Add lgetfield and order for getfield to utils

* Handle all lgetfield and getfield cases with memory, memoryref, and array

* Updates for 1.11

* Run CI on 1.11

* Improve utils for UnionAlls

* Do not inline away debug information

* Tidy up utilities

* Test utilities

* Update tangent functionality

* Fix some codual-related problems

* Fix edge case in memoryrefset

* Handle weird PiNode

* Test that PiNode edge case runs

* Remove commented-out code

* Remove redundant semicolon

* Remove commented-out code

* Move lsetfield definition to utils

* Improve remove_unreachable_blocks

* Uncomment a various array integration tests

* Add lsetfield to custom tangent type testing

* Rename remove_unreachable_block uses for mutating version

* Skip diff_test with compiler problems

* Move rules around to support both 1.11 and 1.10

* Add rule for sincos

* Run more tests on tasks

* Attempt to avoid allocations

* Disable Enzyme in benchmarking on 1.11

* Avoid having to define setfield and lsetfield rules

* Do not test tasks more stringently

* Fix task tests

* Restrict LuxLib versions

* Fix increment and tuple_map for NamedTuple and testing

* generalise lsetfield

* Protect more things for the GC

* Fix typo

* Revert changes to codual types

* Tidy up coduals

* Fix on 1.10

* Sort out 1.10 and 1

* Sort out CI names

* Fix up 1.10 testing

* Restrict runners

* Formatting

* Bump patch
  • Loading branch information
willtebbutt authored Oct 19, 2024
1 parent c9b569c commit 997398c
Show file tree
Hide file tree
Showing 42 changed files with 2,752 additions and 1,559 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ steps:
- label: "Julia v1"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
version: "1"
- JuliaCI/julia-test#v1: ~
- JuliaCI/julia-coverage#v1:
dirs:
Expand Down
16 changes: 12 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ concurrency:
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: ${{ matrix.test_group }}
name: ${{ matrix.test_group }}-${{ matrix.version }}
runs-on: ubuntu-latest
if: github.event_name != 'schedule'
strategy:
Expand All @@ -31,11 +31,19 @@ jobs:
- 'integration_testing/array'
- 'integration_testing/turing'
- 'integration_testing/temporalgps'
version:
- '1'
include:
- test_group: 'basic'
version: '1.10'
- test_group: 'integration_testing/turing'
version: '1.10'

steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1.10'
version: ${{ matrix.version }}
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
Expand All @@ -62,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1.10'
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
Expand All @@ -81,7 +89,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1.10'
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1.10'
version: '1'
arch: x64
include-all-prereleases: false
- name: Install dependencies
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/Manifest.toml
/Manifest-v1.11.toml
dev
bench/Manifest.toml
analysis_results
Expand Down
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.4.11"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
Expand All @@ -18,6 +19,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand All @@ -27,6 +29,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
MooncakeAllocCheckExt = "AllocCheck"
MooncakeCUDAExt = "CUDA"
MooncakeDynamicPPLExt = "DynamicPPL"
MooncakeJETExt = "JET"
Expand All @@ -37,6 +40,7 @@ MooncakeSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1.9"
AllocCheck = "0.2"
BenchmarkTools = "1"
CUDA = "5"
ChainRulesCore = "1"
Expand All @@ -50,18 +54,19 @@ FillArrays = "1"
Graphs = "1"
JET = "0.9"
LogDensityProblemsAD = "1"
LuxLib = "1.2"
LuxLib = "1.2 - 1.3.3"
MistyClosures = "1"
NNlib = "0.9"
PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
TemporalGPs = "0.7"
julia = "1.10"
julia = "1"

[extras]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
Expand All @@ -82,4 +87,4 @@ TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
test = ["AbstractGPs", "AllocCheck", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
13 changes: 9 additions & 4 deletions bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ using Mooncake:
TestUtils,
_typeof

using Mooncake.TestUtils: _deepcopy, to_benchmark
using Mooncake.TestUtils: _deepcopy

function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N}
dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
out, pb!! = __rrule!!(dx_f...)
return pb!!(Mooncake.zero_rdata(primal(out)))
end

function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N}
out, pb = Zygote._pullback(ctx, x...)
Expand Down Expand Up @@ -124,6 +130,7 @@ should_run_benchmark(
should_run_benchmark(
::Val{:enzyme}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
) = false
should_run_benchmark(::Val{:enzyme}, x...) = false

@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N-1)) : x

Expand Down Expand Up @@ -277,9 +284,7 @@ function benchmark_hand_written_rrules!!(rng_ctor)
end

function benchmark_derived_rrules!!(rng_ctor)
test_case_data = map([
:test_utils
]) do s
test_case_data = map([:test_resources]) do s
test_cases, memory = generate_derived_rrule!!_test_cases(rng_ctor, Val(s))
ranges = map(x -> x[3], test_cases)
tags = fill(nothing, length(test_cases))
Expand Down
8 changes: 8 additions & 0 deletions ext/MooncakeAllocCheckExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module MooncakeAllocCheckExt

using AllocCheck, Mooncake
import Mooncake.TestUtils: check_allocs, Shim

@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any, N}}) where {F, N} = f(x...)

end
6 changes: 6 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
include("test_utils.jl")
include("test_resources.jl")

include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
include(joinpath("rrules", "blas.jl"))
Expand All @@ -86,6 +87,11 @@ include(joinpath("rrules", "low_level_maths.jl"))
include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "new.jl"))
include(joinpath("rrules", "tasks.jl"))
@static if VERSION >= v"1.11-rc4"
include(joinpath("rrules", "memory.jl"))
else
include(joinpath("rrules", "array_legacy.jl"))
end

include("interface.jl")
include("config.jl")
Expand Down
9 changes: 8 additions & 1 deletion src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ Equivalent to `CoDual(x, zero_tangent(x))`.
"""
zero_codual(x) = CoDual(x, zero_tangent(x))

"""
uninit_codual(x)
Equivalent to `CoDual(x, uninit_tangent(x))`.
"""
uninit_codual(x) = CoDual(x, uninit_tangent(x))

"""
codual_type(P::Type)
Expand All @@ -32,7 +39,7 @@ The type of the `CoDual` which contains instances of `P` and associated tangents
function codual_type(::Type{P}) where {P}
P == DataType && return CoDual
P isa Union && return Union{codual_type(P.a), codual_type(P.b)}
P <: UnionAll && return CoDual
P <: UnionAll && return CoDual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual
end

Expand Down
2 changes: 1 addition & 1 deletion src/debug_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ _copy(x::P) where {P<:DebugRRule} = P(_copy(x.rule))
Apply type checking to enforce pre- and post-conditions on `rule.rule`. See the docstring
for `DebugRRule` for details.
"""
@inline function (rule::DebugRRule)(x::Vararg{CoDual, N}) where {N}
@noinline function (rule::DebugRRule)(x::Vararg{CoDual, N}) where {N}
verify_fwds_inputs(x)
y, pb = rule.rule(x...)
verify_fwds_output(x, y)
Expand Down
82 changes: 78 additions & 4 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

# The most important bit of this code is `inlining_policy` -- the rest is copy + pasted
# boiler plate, largely taken from https://github.com/JuliaLang/julia/blob/2fe4190b3d26b4eee52b2b1b1054ddd6e38a941e/test/compiler/newinterp.jl#L11
#
# Credit: much of the code in here is copied over from the main Julia repo, and from
# Enzyme.jl, which has a very similar set of concerns to Mooncake in terms of avoiding
# inlining primitive functions.
#

struct ClosureCacheKey
world_age::UInt
Expand Down Expand Up @@ -54,6 +59,8 @@ end

MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx)

context_type(::MooncakeInterpreter{C}) where {C} = C

# Globally cached interpreter. Should only be accessed via `get_interpreter`.
const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter())

Expand All @@ -74,7 +81,6 @@ end

CC.InferenceParams(interp::MooncakeInterpreter) = interp.inf_params
CC.OptimizationParams(interp::MooncakeInterpreter) = interp.opt_params
CC.get_world_counter(interp::MooncakeInterpreter) = interp.world
CC.get_inference_cache(interp::MooncakeInterpreter) = interp.inf_cache
function CC.code_cache(interp::MooncakeInterpreter)
return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
Expand All @@ -97,11 +103,60 @@ function CC.method_table(interp::MooncakeInterpreter)
return CC.OverlayMethodTable(interp.world, mooncake_method_table)
end

if VERSION < v"1.11.0"
CC.get_world_counter(interp::MooncakeInterpreter) = interp.world
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp)
else
CC.get_inference_world(interp::MooncakeInterpreter) = interp.world
CC.cache_owner(::MooncakeInterpreter) = nothing
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp)
end

_type(x) = x
_type(x::CC.Const) = _typeof(x.val)
_type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{_type(x.thentype), _type(x.elsetype)}

struct NoInlineCallInfo <: CC.CallInfo
info::CC.CallInfo # wrapped call
tt::Any # signature
end

CC.nsplit_impl(info::NoInlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoInlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoInlineCallInfo, idx::Int) = CC.getresult(info.info, idx)

function Core.Compiler.abstract_call_gf_by_type(
interp::MooncakeInterpreter{C},
@nospecialize(f),
arginfo::CC.ArgInfo,
si::CC.StmtInfo,
@nospecialize(atype),
sv::CC.AbsIntState,
max_methods::Int,
) where {C}
ret = @invoke CC.abstract_call_gf_by_type(
interp::CC.AbstractInterpreter,
f::Any,
arginfo::CC.ArgInfo,
si::CC.StmtInfo,
atype::Any,
sv::CC.AbsIntState,
max_methods::Int,
)
callinfo = ret.info
if Mooncake.is_primitive(C, atype)
callinfo = NoInlineCallInfo(callinfo, atype)
end
@static if VERSION v"1.11-"
return CC.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
else
return CC.CallMeta(ret.rt, ret.effects, callinfo)
end
end

if VERSION < v"1.11-"

function CC.inlining_policy(
interp::MooncakeInterpreter{C},
@nospecialize(src),
Expand All @@ -112,8 +167,7 @@ function CC.inlining_policy(
) where {C}

# Do not inline away primitives.
argtype_tuple = Tuple{map(_type, argtypes)...}
is_primitive(C, argtype_tuple) && return nothing
info isa NoInlineCallInfo && return nothing

# If not a primitive, AD doesn't care about it. Use the usual inlining strategy.
return @invoke CC.inlining_policy(
Expand All @@ -126,4 +180,24 @@ function CC.inlining_policy(
)
end

context_type(::MooncakeInterpreter{C}) where {C} = C
else # 1.11 and up.

function CC.inlining_policy(
interp::MooncakeInterpreter,
@nospecialize(src),
@nospecialize(info::CC.CallInfo),
stmt_flag::UInt32,
)
# Do not inline away primitives.
info isa NoInlineCallInfo && return nothing

# If not a primitive, AD doesn't care about it. Use the usual inlining strategy.
return @invoke CC.inlining_policy(
interp::CC.AbstractInterpreter,
src::Any,
info::CC.CallInfo,
stmt_flag::UInt32,
)
end

end
Loading

2 comments on commit 997398c

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/117349

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.11 -m "<description of version>" 997398c6189e0a609670c8492225f1c94ee05b65
git push origin v0.4.11

Please sign in to comment.