-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Test if ChainRules problem is resolved"
This reverts commit 1daeb59.
- Loading branch information
Showing
2 changed files
with
23 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Lux" | ||
uuid = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.5.34" | ||
version = "0.5.33" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
|
@@ -29,6 +29,7 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" | |
|
||
[weakdeps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" | ||
|
@@ -42,6 +43,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | |
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[extensions] | ||
LuxChainRulesExt = "ChainRules" | ||
LuxComponentArraysExt = "ComponentArrays" | ||
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] | ||
LuxFluxExt = "Flux" | ||
|
@@ -61,6 +63,7 @@ Adapt = "4" | |
Aqua = "0.8.4" | ||
ArrayInterface = "7.8" | ||
CUDA = "5.2" | ||
ChainRules = "1.62" | ||
ChainRulesCore = "1.21" | ||
ComponentArrays = "0.15.11" | ||
ConcreteStructs = "0.2.3" | ||
|
@@ -105,6 +108,7 @@ julia = "1.10" | |
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
module LuxChainRulesExt | ||
|
||
using ChainRules: ChainRules | ||
|
||
# https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an | ||
# emergency patch here | ||
function ChainRules._setindex_zero( | ||
x::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} | ||
return [fill!(similar(xᵢ), 0) for xᵢ in x] | ||
end | ||
|
||
function ChainRules.∇getindex!( | ||
dx::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} | ||
dx[inds...] .+= dy | ||
return dx | ||
end | ||
|
||
end |
2a20c5a
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128)
3655.5
ns3256
ns1.12
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128)
8373.8
ns7608.416666666667
ns1.10
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128)
18875
ns14687
ns1.29
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128)
9946.6
ns9830.4
ns1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128)
8994.5
ns8746.333333333334
ns1.03
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128)
4253.555555555556
ns4170
ns1.02
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128)
1996.7
ns2007.8
ns0.99
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128)
1654.0845070422536
ns1660.3197278911564
ns1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128)
1806.6203703703704
ns1839.8048780487804
ns0.98
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128)
202.8326241134752
ns179.2422969187675
ns1.13
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128)
17643
ns17402
ns1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128)
18695
ns18545
ns1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128)
36698
ns35837
ns1.02
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128)
29145
ns28814
ns1.01
Dense(20 => 20)/cpu/reverse/Flux/(20, 128)
19827
ns19641.5
ns1.01
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128)
16170
ns16220
ns1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128)
4955
ns4826.285714285715
ns1.03
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128)
5005.142857142857
ns4874.857142857143
ns1.03
Dense(20 => 20)/cpu/forward/Flux/(20, 128)
5036.714285714285
ns4870.5
ns1.03
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128)
1651.1
ns1659.1
ns1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128)
50939380
ns40919886
ns1.24
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128)
79931782
ns105438282.5
ns0.76
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128)
78661421
ns82547287
ns0.95
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128)
91272190.5
ns105389107
ns0.87
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128)
93377866.5
ns101432513
ns0.92
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128)
11991869
ns12101555.5
ns0.99
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128)
18984863.5
ns12139914.5
ns1.56
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128)
18867999
ns18273296.5
ns1.03
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)
18628681
ns17955192
ns1.04
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128)
6427121
ns6406788
ns1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1)
109345995
ns103984169.5
ns1.05
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16)
753707936
ns842558211
ns0.89
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64)
2956953751
ns3036962183
ns0.97
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1)
182122386
ns158105613
ns1.15
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16)
1136980410
ns1091539870.5
ns1.04
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64)
3683425574
ns4156720070
ns0.89
vgg16/cpu/reverse/Flux/(32, 32, 3, 1)
85689271
ns87153400
ns0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 16)
736530234.5
ns677891576.5
ns1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 64)
2815110663
ns3087144996
ns0.91
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1)
25737587
ns25057694
ns1.03
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16)
212298522
ns235060201.5
ns0.90
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64)
932573050.5
ns850027203
ns1.10
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1)
26360056
ns26539452
ns0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16)
212978713
ns222485100
ns0.96
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64)
845080145.5
ns843941878
ns1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 1)
23696110
ns23351068
ns1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 16)
188727204
ns185662518
ns1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 64)
716121501
ns816281223
ns0.88
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128)
1108752997
ns1134116904.5
ns0.98
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128)
1867146644
ns1821706018
ns1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128)
2187251601.5
ns2165104540
ns1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128)
2374190547
ns2350684531
ns1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128)
1864310613.5
ns1833091167
ns1.02
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128)
362695061
ns359054624
ns1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128)
370586681.5
ns458776503
ns0.81
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128)
371425922.5
ns353291618
ns1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128)
12052047
ns11907971
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128)
18235365.5
ns18075976.5
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128)
19603792
ns19244630
ns1.02
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128)
24195903
ns23929082
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128)
18334453
ns18071045
ns1.01
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128)
1154292
ns1162609
ns0.99
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128)
2110196
ns2078115
ns1.02
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128)
2170153
ns2088017.5
ns1.04
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128)
2134836
ns2077539
ns1.03
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128)
202419
ns197480
ns1.03
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128)
306714
ns299121
ns1.03
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128)
277821
ns274264
ns1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128)
372783
ns366778
ns1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128)
418944
ns413700.5
ns1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128)
279253.5
ns275296
ns1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128)
400691
ns395864
ns1.01
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128)
93004
ns88767
ns1.05
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128)
96280
ns89553
ns1.08
Dense(200 => 200)/cpu/forward/Flux/(200, 128)
89628
ns87284
ns1.03
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128)
105237
ns104536
ns1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128)
200438157
ns197678304
ns1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128)
352793790
ns349707282
ns1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128)
408772280
ns394767610
ns1.04
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128)
490040365
ns477350913
ns1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128)
383463253.5
ns371954865
ns1.03
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128)
368210985
ns335078971.5
ns1.10
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128)
53892602.5
ns53540565
ns1.01
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128)
53194010
ns49765921.5
ns1.07
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128)
60542610
ns49896033.5
ns1.21
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128)
29155185
ns28103680
ns1.04
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128)
19965939.5
ns19642944.5
ns1.02
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128)
20208135
ns19748348.5
ns1.02
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128)
24305621.5
ns23593738
ns1.03
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128)
24684238
ns24233000.5
ns1.02
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128)
20154517
ns19740162
ns1.02
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128)
6741537.5
ns6615632.5
ns1.02
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128)
6772766
ns6593397.5
ns1.03
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128)
6680651
ns6506655
ns1.03
This comment was automatically generated by workflow using github-action-benchmark.