Error differentiating through complex power due to isinteger #637

Open
tom-plaa opened this issue Mar 20, 2023 · 6 comments
Comments


tom-plaa commented Mar 20, 2023

I have a weird case where a complex power throws an InexactError because Base's _cpow helper function decides that my Dual exponent is an integer and then tries to convert it to Int.

Truncated error message:

julia> ForwardDiff.gradient(test_func, ptest)
ERROR: InexactError: Int(Int64, Dual{ForwardDiff.Tag{typeof(test_func), Float64}}(1.0,-0.0,-0.8333333333333333,-0.0,-0.0,-0.17599538919114094))
Stacktrace:
  [1] Int64
    @ ~/.julia/packages/ForwardDiff/QdStj/src/dual.jl:364 [inlined]
  [2] convert(#unused#::Type{Int64}, x::ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_func), Float64}, Float64, 5})                                  
    @ Base ./number.jl:7
  [3] _cpow(z::Complex{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_func), Float64}, Float64, 5}}, p::ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_func), Float64}, Float64, 5})
    @ Base ./complex.jl:782
  [4] ^
    @ ./complex.jl:851 [inlined]
  [5] ^
    @ ./complex.jl:866 [inlined]

So given that

julia> typeof(test_func)
typeof(test_func) (singleton type of function test_func, subtype of Function)

It seems that we can get the following MWE:

julia> isinteger(Dual{ForwardDiff.Tag{Function, Float64}}(1.0,-0.0,-0.8333333333333333,-0.0,-0.0,-0.17599538919114094))
true

This is the associated culprit line in complex.jl that tries to do this conversion to Int (only the relevant part is shown):

# _cpow helper function to avoid method ambiguity with ^(::Complex,::Real)
function _cpow(z::Union{T,Complex{T}}, p::Union{T,Complex{T}}) where T
    z = float(z)
    p = float(p)
    Tf = float(T)
    if isreal(p)
        pᵣ = real(p)
        if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32) 
            # |p| < typemax(Int32) serves two purposes: it prevents overflow
            # when converting p to Int, and it also turns out to be roughly
            # the crossover point for exp(p*log(z)) or similar to be faster.
            if iszero(pᵣ) # fix signs of imaginary part for z^0
                zer = flipsign(copysign(zero(Tf),pᵣ), imag(z))
                return Complex(one(Tf), zer)
            end
            ip = convert(Int, pᵣ) # <--- HERE!

Admittedly, I am not sure if the partials should be considered for determining if it's an integer or not.

Member

mcabbott commented Mar 20, 2023

What is test_func?

Guessing from the stack trace, I think the MWE is this:

julia> (Dual(1, 0.2) + 0im) ^ Dual(3, 0.1)
ERROR: InexactError: Int(Int64, Dual{Nothing}(3.0,0.1))
Stacktrace:
 [1] Int64
   @ ~/.julia/packages/ForwardDiff/vXysl/src/dual.jl:364 [inlined]
 [2] convert
   @ ./number.jl:7 [inlined]
 [3] _cpow(z::Complex{Dual{Nothing, Float64, 1}}, p::Dual{Nothing, Float64, 1})
   @ Base ./complex.jl:791
 [4] ^(z::Complex{Dual{Nothing, Float64, 1}}, p::Dual{Nothing, Float64, 1})
   @ Base ./complex.jl:860

Changing isinteger might be one way to fix this. The same function came up (without an example) here: #481 (comment)

@tom-plaa
Author

test_func is a function that constructs a pdf from a Fourier-inverted characteristic function, and the result is an object from the Interpolations package. But I believe this might indeed be related to the comment you just linked. I wonder why I didn't hit this before: I have been differentiating through this particular piece of code for quite some time, and it still works in the earlier cases where I do it.

Author

tom-plaa commented Mar 20, 2023

It's in this particular line from https://gitlab.com/tom.plaa/CharacteristicInvFourier.jl/-/blob/main/src/CharacteristicInvFourier.jl#L101

D = (-1+0im) .^ (-2 * (xmin / xrange) * k)

So yes, equivalent it seems. What's confusing to me is that it never gave that error before, and now I tried to declare a single local instance of that function inside another and suddenly everything breaks.
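
A possible workaround, not proposed in the thread, is to avoid `^` for this particular base. On the principal branch, (-1)^p = exp(iπp) for real p, which Base exposes as `cispi(p)` (Julia 1.6 or later). The sketch below only checks that identity on plain Float64 inputs; the values of `xmin`, `xrange` and `k` are hypothetical stand-ins, and whether `cispi` differentiates cleanly on Dual inputs with a given ForwardDiff version is an assumption that would need to be verified.

```julia
# Hypothetical sample values standing in for xmin, xrange and k.
xmin, xrange = -1.5, 4.0
k = 0:3

p = -2 * (xmin / xrange) .* k      # real exponents, built as in the line above

D_pow   = (-1 + 0im) .^ p          # original formulation, goes through _cpow
D_cispi = cispi.(p)                # exp(iπp), no isinteger shortcut involved

# Largest elementwise difference between the two formulations.
maxdiff = maximum(abs.(D_pow .- D_cispi))
```

The two results should agree up to floating-point rounding, including at integer exponents where `_cpow` takes its shortcut.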

Member

mcabbott commented Mar 20, 2023

Ok. One guess might be that earlier operations made the dual part zero, or the real part non-integer? But I don't know. I don't think ForwardDiff has changed much.

Edit: Seeing broadcasting in your line, if you are using Zygote, its broadcasting has a fairly recent change. Before, this would use Dual only for real numbers, and something much slower otherwise. After, it uses Dual for complex numbers too.

Note that unlike the other problems in #481, I don't think this shortcut behind an isinteger check can lead to silent wrong answers. Converting to Int is better than trusting the check & then truncating.

julia> (Dual(1.1, 0.2) + 0im) ^ Dual(3, 0.0)  # takes the isinteger shortcut
Dual{Nothing}(1.3310000000000004,0.7260000000000002) + Dual{Nothing}(0.0,0.0)*im

julia> (Dual(1.1, 0.2) + 0im) ^ Dual(3.00001, 0.0)  # does not, should be very close
Dual{Nothing}(1.3310012685790982,0.7260031119545418) + Dual{Nothing}(0.0,0.0)*im
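
The dual parts in the two results above can be checked by hand with the power rule, d(z^p) = p * z^(p-1) * dz, since the exponent carries a zero partial in both cases. This is only a plain Float64 sanity check of the printed numbers, not a statement about how ForwardDiff computes them.

```julia
# Values taken from the two examples above: z = 1.1 with partial dz = 0.2.
z, dz = 1.1, 0.2

# Power rule for the integer exponent p = 3 (shortcut branch).
d_shortcut = 3 * z^2 * dz

# Power rule for the nearby non-integer exponent p = 3.00001.
d_nearby = 3.00001 * z^(3.00001 - 1) * dz

@show d_shortcut d_nearby
```

Both values should agree with the dual parts printed above to within rounding, which supports the point that the integer shortcut and the general branch give consistent derivatives.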

Author

tom-plaa commented Mar 20, 2023

Interesting point, but do you mean I can try that temporary hack to get around the conversion for now?
In my particular case, if I do the following inside a given function it will not complain

        numerical_logpdf(u) = log(numerical_pdf(D(distcoefs...),
                                                npower=16, widthfactor=9.5)
                                  )

        fixedlogkernel(x) = (try sig=std(D(distcoefs...));
                                 numerical_logpdf(mean(D(distcoefs...)) + sig*x)+log(sig);
                             catch;
                                 T(-Inf);
                             end)

Here numerical_pdf is a call to the numerical inversion function mentioned before, D is a type of mine such that D <: Distributions.ContinuousUnivariateDistribution, and distcoefs is the vector with respect to which the gradient is being calculated. I'm basically redefining some functions from https://github.com/s-broda/ARCHModels.jl/blob/2dfe36ac7c21482e7bbf2390ab66bcaf53a503bd/src/univariatearchmodel.jl to use my own distribution type, which needs to calculate the density numerically. It doesn't throw, but the results don't look right so far; that's unrelated, though.

Here is the version that raises this error:

dist = D(distcoefs...)
mypdf = numerical_pdf(dist, npower=16, widthfactor=9.5)  # It breaks when the gradient calculation gets here
numerical_logpdf(u) = log(mypdf(u))

fixedlogkernel(x) = (try sig=std(dist);
                         numerical_logpdf(mean(D(distcoefs...)) + sig*x)+log(sig);
                     catch;
                         T(-Inf);
                     end)

I don't understand why declaring my function as a local instance breaks, but the other case doesn't.

Author

tom-plaa commented Mar 21, 2023

I understand the difference now: the try/catch statement was just hiding the error and assigning -Inf all the way down.
So, why can't we have
isinteger(x::Dual) = isinteger(x.value) && all(isinteger.(x.partials)) ?

It would be like:

julia> test_isinteger(x::Dual) = isinteger(x.value) && all(isinteger.(x.partials))
test_isinteger (generic function with 1 method)

julia> test_isinteger(Dual(1.0, 0.2))
false

julia> test_isinteger(Dual(1.0, 0.0))
true

Although we currently have

@inline function Base.Int(d::Dual)
    all(iszero, partials(d)) || throw(InexactError(:Int, Int, d))
    Int(value(d))
end

Which means that to make this coherent it would instead be
isinteger(x::Dual) = isinteger(x.value) && all(iszero.(x.partials)) ?
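
To make the mismatch concrete, here is a minimal self-contained sketch that does not use ForwardDiff at all. `ToyDual` and every function name below are hypothetical stand-ins; they only mimic the two behaviours quoted in this thread (a value-only isinteger and an Int conversion that rejects nonzero partials) and compare them with the stricter predicate suggested above.

```julia
# Toy stand-in for ForwardDiff.Dual; all names here are hypothetical.
struct ToyDual
    value::Float64
    partials::Vector{Float64}
end

# Value-only predicate, matching the behaviour reported at the top of the thread.
value_only_isinteger(d::ToyDual) = isinteger(d.value)

# Stricter predicate that also requires every partial to be zero.
strict_isinteger(d::ToyDual) = isinteger(d.value) && all(iszero, d.partials)

# Conversion that refuses to drop nonzero partials, mirroring Base.Int(d::Dual).
function toy_int(d::ToyDual)
    all(iszero, d.partials) || throw(InexactError(:Int, Int, d.value))
    return Int(d.value)
end

# True when "pred(d) passes" never leads to a throwing conversion for this d.
function shortcut_is_safe(pred, d::ToyDual)
    pred(d) || return true          # shortcut not taken, nothing is converted
    try
        toy_int(d)
        return true
    catch err
        err isa InexactError || rethrow()
        return false
    end
end

d = ToyDual(3.0, [0.1])
unsafe = shortcut_is_safe(value_only_isinteger, d)  # false: check passes, conversion throws
safe   = shortcut_is_safe(strict_isinteger, d)      # true: the check already rejects d
```

With the stricter predicate, an exponent like the one in the MWE would skip the integer shortcut and fall through to the general branch instead of reaching the conversion.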
