diff --git a/Project.toml b/Project.toml index 3af67f51c..ebe693ae7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.4.0" +version = "1.4.1-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -79,7 +79,7 @@ DispatchDoctor = "0.4.12" Enzyme = "0.13.16" EnzymeCore = "0.8.6" FastClosures = "0.3.2" -Flux = "0.14.25" +Flux = "0.15" ForwardDiff = "0.10.36" FunctionWrappers = "1.1.3" Functors = "0.5" diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index d0f89b2b0..c77b24c04 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -211,11 +211,11 @@ function Lux.convert_flux_model( return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ) end -const _INVALID_TRANSFORMATION_TYPES = Union{<:Flux.Recur} +# const _INVALID_TRANSFORMATION_TYPES = Union{} -function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES} - throw(FluxModelConversionException("Transformation of type $(T) is not supported.")) -end +# function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES} +# throw(FluxModelConversionException("Transformation of type $(T) is not supported.")) +# end for cell in (:RNNCell, :LSTMCell, :GRUCell) msg = "Recurrent Cell: $(cell) for Flux has semantical difference with Lux, \ diff --git a/test/runtests.jl b/test/runtests.jl index 130ea0275..6837b9ae0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,8 +26,7 @@ if ("all" in LUX_TEST_GROUP || "misc" in LUX_TEST_GROUP) push!(EXTRA_PKGS, Pkg.PackageSpec("MPI")) (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, Pkg.PackageSpec("NCCL")) - # XXX: Reactivate once Flux is compatible with Functors 0.5 - # push!(EXTRA_PKGS, Pkg.PackageSpec("Flux")) + push!(EXTRA_PKGS, Pkg.PackageSpec("Flux")) end if !Sys.iswindows() diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index 38bc1c176..6b8676522 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -1,4 +1,4 @@ -@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] skip=:(true) begin +@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] begin import Flux toluxpsst = FromFluxAdaptor(; preserve_ps_st=true)