diff --git a/src/functions.jl b/src/functions.jl index 83f8380..509de30 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -12,13 +12,44 @@ function square(x::Real) end -invpow2(x::Real, p::Integer) = sign(x) * abs(x)^inv(p) -invpow2(x::Real, p::Real) = x ≥ zero(x) ? x^inv(p) : throw(DomainError(x, "inverse for x^$p is not defined at $x")) -invpow2(x, p) = x^inv(p) +function invpow2(x::Real, p::Integer) + if x ≥ zero(x) || isodd(p) + copysign(abs(x)^inv(p), x) + else + throw(DomainError(x, "inverse for x^$p is not defined at $x")) + end +end +function invpow2(x::Real, p::Real) + if x ≥ zero(x) + x^inv(p) + else + throw(DomainError(x, "inverse for x^$p is not defined at $x")) + end +end +function invpow2(x, p::Real) + # complex x^p is only invertible for p = 1/n + if isinteger(inv(p)) + x^inv(p) + else + throw(DomainError(x, "inverse for x^$p is not defined at $x")) + end +end -invpow1(b, x) = log(abs(b), abs(x)) +function invpow1(b::Real, x::Real) + if b ≥ zero(b) && x ≥ zero(x) + log(b, x) + else + throw(DomainError(x, "inverse for $b^x is not defined at $x")) + end +end -invlog1(b::Real, x::Real) = b ≥ zero(b) && x ≥ zero(x) ? b^x : throw(DomainError(x, "inverse for log($b, x) is not defined at $x")) +function invlog1(b::Real, x::Real) + if b ≥ zero(b) + b^x + else + throw(DomainError(x, "inverse for log($b, x) is not defined at $x")) + end +end invlog1(b, x) = b^x invlog2(b, x) = x^inv(b) diff --git a/test/test_inverse.jl b/test/test_inverse.jl index 57fd92a..fe8599b 100644 --- a/test/test_inverse.jl +++ b/test/test_inverse.jl @@ -37,7 +37,7 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A)) x = rand() for f in ( foo, inv_foo, log, log2, log10, log1p, sqrt, - Base.Fix2(^, rand()), Base.Fix2(^, rand([-10:-1; 1:10])), Base.Fix1(^, rand()), Base.Fix1(log, rand()), Base.Fix2(log, rand()), + Base.Fix2(^, rand()), Base.Fix2(^, rand([-10:-1; 1:10])), Base.Fix1(^, rand()), Base.Fix1(log, rand()), Base.Fix1(log, 1/rand()), Base.Fix2(log, rand()), ) InverseFunctions.test_inverse(f, x) end @@ -55,10 +55,16 @@ InverseFunctions.inverse(f::Bar) = Bar(inv(f.A)) # ensure that inverses have domains compatible with original functions @test_throws DomainError inverse(Base.Fix1(*, 0)) @test_throws DomainError inverse(Base.Fix2(^, 0)) + @test_throws DomainError inverse(Base.Fix1(log, -2))(5) @test_throws DomainError inverse(Base.Fix1(log, 2))(-5) - InverseFunctions.test_inverse(Base.Fix1(log, 2), -5 + 0im) + InverseFunctions.test_inverse(inverse(Base.Fix1(log, 2)), complex(-5)) @test_throws DomainError inverse(Base.Fix2(^, 0.5))(-5) - InverseFunctions.test_inverse(Base.Fix2(^, 0.5), -5 + 0im) + @test_throws DomainError inverse(Base.Fix2(^, 0.51))(complex(-5)) + InverseFunctions.test_inverse(Base.Fix2(^, 0.5), complex(-5)) + @test_throws DomainError inverse(Base.Fix2(^, 2))(-5) + @test_throws DomainError inverse(Base.Fix1(^, 2))(-5) + @test_throws DomainError inverse(Base.Fix1(^, -2))(3) + @test_throws DomainError inverse(Base.Fix1(^, -2))(3) A = rand(5, 5) for f in (