diff --git a/docs/src/index.md b/docs/src/index.md index e1ff5b1..6d2c6b9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,6 +6,6 @@ InverseFunctions This package defines the function [`inverse`](@ref). `inverse(f)` returns the inverse function of a function `f`, so that `inverse(f)(f(x)) ≈ x`. -`inverse` supports mapped/broadcasted functions (via `Base.Fix1`) and (on Julia >=v1.6) function composition. +`inverse` supports mapped/broadcasted functions (via `Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and (on Julia >=v1.6) function composition. Implementations of `inverse(f)` for `identity`, `inv`, `adjoint` and `transpose` as well as for `exp`, `log`, `exp2`, `log2`, `exp10`, `log10`, `expm1`, `log1p` and `sqrt` are included. diff --git a/src/inverse.jl b/src/inverse.jl index ee3d1c6..67c4efa 100644 --- a/src/inverse.jl +++ b/src/inverse.jl @@ -5,8 +5,9 @@ Return the inverse of function `f`. -`inverse` supports mapped and broadcasted functions (via `Base.Fix1`) and -function composition (requires Julia >= 1.6). +`inverse` supports mapped and broadcasted functions (via +`Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and function composition +(requires Julia >= 1.6). # Examples @@ -27,7 +28,7 @@ true julia> inverse(inverse(foo)) === foo true -julia> broadcast_foo = Base.Fix1(broadcast, foo); +julia> broadcast_foo = VERSION >= v"1.6" ? Base.Broadcast.BroadcastFunction(foo) : Base.Fix1(broadcast, foo); julia> X = rand(10); @@ -84,6 +85,15 @@ inverse(::typeof(inverse)) = inverse Base.ComposedFunction(inv_inner, inv_outer) end end + + function inverse(bf::Base.Broadcast.BroadcastFunction) + inv_f_kernel = inverse(bf.f) + if inv_f_kernel isa NoInverse + NoInverse(bf) + else + Base.Broadcast.BroadcastFunction(inv_f_kernel) + end + end end function inverse(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}) @@ -91,7 +101,7 @@ function inverse(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}) if inv_f_kernel isa NoInverse NoInverse(mapped_f) else - Base.Fix1(mapped_f.f, inverse(mapped_f.x)) + Base.Fix1(mapped_f.f, inv_f_kernel) end end diff --git a/test/test_inverse.jl b/test/test_inverse.jl index a1a8687..0bd33f4 100644 --- a/test/test_inverse.jl +++ b/test/test_inverse.jl @@ -20,12 +20,18 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A)) @testset "inverse" begin + @static if VERSION >= v"1.6" + _bc_func(f) = Base.Broadcast.BroadcastFunction(f) + else + _bc_func(f) = Base.Fix1(broadcast, f) + end + f_without_inverse(x) = 1 @test inverse(f_without_inverse) isa NoInverse @test_throws ErrorException inverse(f_without_inverse)(42) @test inverse(inverse(f_without_inverse)) === f_without_inverse - for f in (f_without_inverse ∘ exp, exp ∘ f_without_inverse, Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse)) + for f in (f_without_inverse ∘ exp, exp ∘ f_without_inverse, _bc_func(f_without_inverse), Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse)) @test inverse(f) == NoInverse(f) @test inverse(inverse(f)) == f end @@ -96,7 +102,7 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A)) end X = rand(5) - for f in (Base.Fix1(broadcast, foo), Base.Fix1(map, foo)) + for f in (_bc_func(foo), Base.Fix1(broadcast, foo), Base.Fix1(map, foo)) for x in (x, fill(x, 3), X) InverseFunctions.test_inverse(f, x) end