Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Merge pull request #22 from LuxDL/ap/simplify_ext
Browse files Browse the repository at this point in the history
Use PackageExtensionCompat
  • Loading branch information
avik-pal authored Jun 26, 2023
2 parents d3274eb + a708262 commit 6980015
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 82 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.2.4"
version = "0.2.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Expand All @@ -32,8 +32,8 @@ ForwardDiff = "0.10"
KernelAbstractions = "0.9"
LuxCUDA = "0.1"
NNlib = "0.8, 0.9"
PackageExtensionCompat = "1"
Reexport = "1"
Requires = "1"
ReverseDiff = "1"
Tracker = "0.2"
julia = "1.6"
Expand Down
3 changes: 1 addition & 2 deletions ext/LuxLibForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module LuxLibForwardDiffExt

isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
using LuxLib
using ForwardDiff, LuxLib

function LuxLib._dropout_fptype(x::AbstractArray{<:ForwardDiff.Dual})
return ForwardDiff.valtype(eltype(x))
Expand Down
3 changes: 1 addition & 2 deletions ext/LuxLibLuxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module LuxLibLuxCUDAExt

isdefined(Base, :get_extension) ? (using LuxCUDA) : (using ..LuxCUDA)
using LuxLib
using LuxCUDA, LuxLib
import ChainRulesCore as CRC
import LuxLib: batchnorm, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅

Expand Down
13 changes: 2 additions & 11 deletions ext/LuxLibLuxCUDATrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
module LuxLibLuxCUDATrackerExt

if isdefined(Base, :get_extension)
using Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
using LuxCUDA
else
using ..Tracker
import ..Tracker: @grad,
data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
using ..LuxCUDA
end
using NNlib, LuxLib
using NNlib, LuxCUDA, LuxLib, Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
import LuxLib: AA,
AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked

Expand Down
37 changes: 11 additions & 26 deletions ext/LuxLibReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,18 @@
module LuxLibReverseDiffExt

if isdefined(Base, :get_extension)
using ReverseDiff
import ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules
else
using ..ReverseDiff
import ..ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules
end
using ChainRulesCore, LuxLib, NNlib
using ChainRulesCore, LuxLib, NNlib, ReverseDiff
import ChainRulesCore as CRC
import LuxLib: AA, __is_tracked
import ReverseDiff: SpecialInstruction,
TrackedArray,
TrackedReal,
decrement_deriv!,
increment_deriv!,
track,
value,
special_reverse_exec!,
special_forward_exec!,
@grad_from_chainrules

# Patches: Needs upstreaming
@inline function increment_deriv!(t::Union{TrackedArray, TrackedReal}, ::NoTangent, i)
Expand Down
11 changes: 2 additions & 9 deletions ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
module LuxLibTrackerExt

if isdefined(Base, :get_extension)
using Tracker
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
else
using ..Tracker
import ..Tracker: @grad,
data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
end
using NNlib, LuxLib
using NNlib, LuxLib, Tracker
import LuxLib: AA,
AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked
import ChainRulesCore as CRC
import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal

# NNlib: batched_mul
for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray)
Expand Down
31 changes: 2 additions & 29 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,9 @@ using KernelAbstractions
import KernelAbstractions as KA

# Extensions
if !isdefined(Base, :get_extension)
using Requires
end

using PackageExtensionCompat
function __init__()
@static if !isdefined(Base, :get_extension)
# Handling AD Packages
## Handling ForwardDiff
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/LuxLibForwardDiffExt.jl")
end
## Handling Tracker
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/LuxLibTrackerExt.jl")
end
## Handling ReverseDiff
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("../ext/LuxLibReverseDiffExt.jl")
end

# Accelerator Support
## Handling CUDA
@require LuxCUDA="d0bbae9a-e099-4d5b-a835-1c6931763bda" begin
include("../ext/LuxLibLuxCUDAExt.jl")

@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/LuxLibLuxCUDATrackerExt.jl")
end
end
end
@require_extensions
end

include("utils.jl")
Expand Down

2 comments on commit 6980015

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86300

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.5 -m "<description of version>" 6980015626af1cc5d1dcc1669bc569fb242fdab3
git push origin v0.2.5

Please sign in to comment.