Skip to content

Commit

Permalink
Support Base.Broadcast.BroadcastFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Mar 28, 2023
1 parent 0eefe9d commit fd3396d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ InverseFunctions

This package defines the function [`inverse`](@ref). `inverse(f)` returns the inverse function of a function `f`, so that `inverse(f)(f(x)) ≈ x`.

`inverse` supports mapped/broadcasted functions (via `Base.Fix1`) and (on Julia >=v1.6) function composition.
`inverse` supports mapped/broadcasted functions (via `Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and (on Julia >=v1.6) function composition.

Implementations of `inverse(f)` for `identity`, `inv`, `adjoint` and `transpose` as well as for `exp`, `log`, `exp2`, `log2`, `exp10`, `log10`, `expm1`, `log1p` and `sqrt` are included.
18 changes: 14 additions & 4 deletions src/inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
Return the inverse of function `f`.
`inverse` supports mapped and broadcasted functions (via `Base.Fix1`) and
function composition (requires Julia >= 1.6).
`inverse` supports mapped and broadcasted functions (via
`Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and function composition
(requires Julia >= 1.6).
# Examples
Expand All @@ -27,7 +28,7 @@ true
julia> inverse(inverse(foo)) === foo
true
julia> broadcast_foo = Base.Fix1(broadcast, foo);
julia> broadcast_foo = VERSION >= v"1.6" ? Base.Broadcast.BroadcastFunction(foo) : Base.Fix1(broadcast, foo);
julia> X = rand(10);
Expand Down Expand Up @@ -84,14 +85,23 @@ inverse(::typeof(inverse)) = inverse
Base.ComposedFunction(inv_inner, inv_outer)
end
end

function inverse(bf::Base.Broadcast.BroadcastFunction)
inv_f_kernel = inverse(bf.f)
if inv_f_kernel isa NoInverse
NoInverse(bf)
else
Base.Broadcast.BroadcastFunction(inv_f_kernel)
end
end
end

function inverse(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}})
inv_f_kernel = inverse(mapped_f.x)
if inv_f_kernel isa NoInverse
NoInverse(mapped_f)
else
Base.Fix1(mapped_f.f, inverse(mapped_f.x))
Base.Fix1(mapped_f.f, inv_f_kernel)
end
end

Expand Down
10 changes: 8 additions & 2 deletions test/test_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A))


@testset "inverse" begin
@static if VERSION >= v"1.6"
_bc_func(f) = Base.Broadcast.BroadcastFunction(f)
else
_bc_func(f) = Base.Fix1(broadcast, f)
end

f_without_inverse(x) = 1
@test inverse(f_without_inverse) isa NoInverse
@test_throws ErrorException inverse(f_without_inverse)(42)
@test inverse(inverse(f_without_inverse)) === f_without_inverse

for f in (f_without_inverse exp, exp f_without_inverse, Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse))
for f in (f_without_inverse exp, exp f_without_inverse, _bc_func(f_without_inverse), Base.Fix1(broadcast, f_without_inverse), Base.Fix1(map, f_without_inverse))
@test inverse(f) == NoInverse(f)
@test inverse(inverse(f)) == f
end
Expand Down Expand Up @@ -96,7 +102,7 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A))
end

X = rand(5)
for f in (Base.Fix1(broadcast, foo), Base.Fix1(map, foo))
for f in (_bc_func(foo), Base.Fix1(broadcast, foo), Base.Fix1(map, foo))
for x in (x, fill(x, 3), X)
InverseFunctions.test_inverse(f, x)
end
Expand Down

0 comments on commit fd3396d

Please sign in to comment.