Use new public feature (#2342)
mcabbott authored Dec 1, 2023
1 parent 92f8982 commit ed5a3b8
Showing 2 changed files with 53 additions and 0 deletions.
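For context, this commit relies on Julia 1.11's new `public` keyword, which marks a name as part of a module's documented API without exporting it; Compat.jl's `@compat public` applies the same marking on 1.11 and expands to nothing on older Julia versions. A minimal sketch of the keyword itself, with hypothetical names, assuming Julia ≥ 1.11:

    module Demo
    export exported_fn      # exported: `using Demo` brings this name into scope
    public internal_api     # Julia >= 1.11: part of the API, but only reachable as `Demo.internal_api`

    internal_api() = 41
    exported_fn() = internal_api() + 1
    end

    Demo.exported_fn()      # 42
    Demo.internal_api()     # 41, qualified access to a public-but-unexported name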
2 changes: 2 additions & 0 deletions Project.toml
@@ -5,6 +5,7 @@ version = "0.14.6"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -38,6 +39,7 @@ AMDGPU = "0.6, 0.7"
Adapt = "3.0"
CUDA = "4, 5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Functors = "0.4"
MLUtils = "0.4"
MacroTools = "0.5"
51 changes: 51 additions & 0 deletions src/Flux.jl
@@ -1,6 +1,7 @@
module Flux

using Base: tail
using Compat: @compat # for @compat public
using Preferences
using LinearAlgebra, Statistics, Random # standard lib
using MacroTools, Reexport, ProgressLogging, SpecialFunctions
@@ -30,6 +31,15 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32,
testmode!, trainmode!

@compat(public, ( # mark unexported symbols as API, on Julia 1.11
# modules
Losses,
# layers
Bilinear, Scale, dropout,
# utils
outputsize, state,
))

include("optimise/Optimise.jl")
using .Optimise
export Descent, Adam, Momentum, Nesterov, RMSProp,
@@ -47,6 +57,15 @@ using Adapt, Functors, OneHotArrays
include("utils.jl")
include("functor.jl")

@compat(public, (
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Functors.jl
functor, @functor,
# from Optimise/Train/Optimisers.jl
setup, update!, destructure, freeze!, adjust!, params, trainable
))

# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

@@ -69,5 +88,37 @@ include("deprecations.jl")
include("losses/Losses.jl")
using .Losses

@compat(public, (
# init
glorot_uniform,
glorot_normal,
kaiming_uniform,
kaiming_normal,
truncated_normal,
orthogonal,
sparse_init,
identity_init,

# Losses
binary_focal_loss,
binarycrossentropy,
crossentropy,
dice_coeff_loss,
focal_loss,
hinge_loss,
huber_loss,
kldivergence,
label_smoothing,
logitbinarycrossentropy,
logitcrossentropy,
mae,
mse,
msle,
poisson_loss,
siamese_contrastive_loss,
squared_hinge_loss,
tversky_loss,
))


end # module
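
For reference (not part of the diff above): on Julia 1.11+ the `@compat public` calls expand to `public` declarations, so the listed names count as API without being exported, while on earlier Julia versions the macro is a no-op. A rough usage sketch, assuming Julia 1.11 with this commit applied; outputs shown are illustrative:

    using Flux

    # Public-but-unexported names are used with a qualified path:
    Flux.outputsize(Dense(2 => 3), (2, 1))     # (3, 1)

    # Julia 1.11 can query the marking (exported names also count as public):
    Base.ispublic(Flux, :outputsize)   # true
    Base.ispublic(Flux, :Dense)        # true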
