Skip to content

Commit

Permalink
Revert "Test if ChainRules problem is resolved"
Browse files Browse the repository at this point in the history
This reverts commit 1daeb59.
  • Loading branch information
avik-pal committed Apr 9, 2024
1 parent 1daeb59 commit 2a20c5a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.5.34"
version = "0.5.33"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -29,6 +29,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
Expand All @@ -42,6 +43,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxChainRulesExt = "ChainRules"
LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxExt = "Flux"
Expand All @@ -61,6 +63,7 @@ Adapt = "4"
Aqua = "0.8.4"
ArrayInterface = "7.8"
CUDA = "5.2"
ChainRules = "1.62"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.11"
ConcreteStructs = "0.2.3"
Expand Down Expand Up @@ -105,6 +108,7 @@ julia = "1.10"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Expand Down
18 changes: 18 additions & 0 deletions ext/LuxChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module LuxChainRulesExt

using ChainRules: ChainRules

# https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an
# emergency patch here
function ChainRules._setindex_zero(
x::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number}
return [fill!(similar(xᵢ), 0) for xᵢ in x]
end

function ChainRules.∇getindex!(
dx::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number}
dx[inds...] .+= dy
return dx
end

end

1 comment on commit 2a20c5a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 2a20c5a Previous: abdbf4b Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3655.5 ns 3256 ns 1.12
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 8373.8 ns 7608.416666666667 ns 1.10
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 18875 ns 14687 ns 1.29
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9946.6 ns 9830.4 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8994.5 ns 8746.333333333334 ns 1.03
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4253.555555555556 ns 4170 ns 1.02
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1996.7 ns 2007.8 ns 0.99
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1654.0845070422536 ns 1660.3197278911564 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1806.6203703703704 ns 1839.8048780487804 ns 0.98
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 202.8326241134752 ns 179.2422969187675 ns 1.13
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17643 ns 17402 ns 1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 18695 ns 18545 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 36698 ns 35837 ns 1.02
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29145 ns 28814 ns 1.01
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19827 ns 19641.5 ns 1.01
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 16170 ns 16220 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 4955 ns 4826.285714285715 ns 1.03
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 5005.142857142857 ns 4874.857142857143 ns 1.03
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 5036.714285714285 ns 4870.5 ns 1.03
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1651.1 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 50939380 ns 40919886 ns 1.24
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 79931782 ns 105438282.5 ns 0.76
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 78661421 ns 82547287 ns 0.95
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 91272190.5 ns 105389107 ns 0.87
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 93377866.5 ns 101432513 ns 0.92
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11991869 ns 12101555.5 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 18984863.5 ns 12139914.5 ns 1.56
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 18867999 ns 18273296.5 ns 1.03
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 18628681 ns 17955192 ns 1.04
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6427121 ns 6406788 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) 109345995 ns 103984169.5 ns 1.05
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 753707936 ns 842558211 ns 0.89
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2956953751 ns 3036962183 ns 0.97
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) 182122386 ns 158105613 ns 1.15
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 1136980410 ns 1091539870.5 ns 1.04
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3683425574 ns 4156720070 ns 0.89
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) 85689271 ns 87153400 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 736530234.5 ns 677891576.5 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2815110663 ns 3087144996 ns 0.91
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) 25737587 ns 25057694 ns 1.03
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 212298522 ns 235060201.5 ns 0.90
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 932573050.5 ns 850027203 ns 1.10
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) 26360056 ns 26539452 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 212978713 ns 222485100 ns 0.96
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 845080145.5 ns 843941878 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 1) 23696110 ns 23351068 ns 1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 188727204 ns 185662518 ns 1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 716121501 ns 816281223 ns 0.88
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1108752997 ns 1134116904.5 ns 0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1867146644 ns 1821706018 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2187251601.5 ns 2165104540 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2374190547 ns 2350684531 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1864310613.5 ns 1833091167 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 362695061 ns 359054624 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 370586681.5 ns 458776503 ns 0.81
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 371425922.5 ns 353291618 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12052047 ns 11907971 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18235365.5 ns 18075976.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19603792 ns 19244630 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 24195903 ns 23929082 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18334453 ns 18071045 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1154292 ns 1162609 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2110196 ns 2078115 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2170153 ns 2088017.5 ns 1.04
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2134836 ns 2077539 ns 1.03
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 202419 ns 197480 ns 1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 306714 ns 299121 ns 1.03
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 277821 ns 274264 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 372783 ns 366778 ns 1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 418944 ns 413700.5 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 279253.5 ns 275296 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 400691 ns 395864 ns 1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 93004 ns 88767 ns 1.05
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 96280 ns 89553 ns 1.08
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 89628 ns 87284 ns 1.03
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 105237 ns 104536 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 200438157 ns 197678304 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 352793790 ns 349707282 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 408772280 ns 394767610 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 490040365 ns 477350913 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 383463253.5 ns 371954865 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 368210985 ns 335078971.5 ns 1.10
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 53892602.5 ns 53540565 ns 1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 53194010 ns 49765921.5 ns 1.07
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 60542610 ns 49896033.5 ns 1.21
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 29155185 ns 28103680 ns 1.04
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19965939.5 ns 19642944.5 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 20208135 ns 19748348.5 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 24305621.5 ns 23593738 ns 1.03
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24684238 ns 24233000.5 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 20154517 ns 19740162 ns 1.02
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6741537.5 ns 6615632.5 ns 1.02
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6772766 ns 6593397.5 ns 1.03
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6680651 ns 6506655 ns 1.03

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.