
Add special case for derivative of modified_bessel_function(0,x) to greatly improve model estimation speed #3009

Open
venpopov opened this issue Jan 20, 2024 · 1 comment


@venpopov
Contributor

Description

As described here, evaluating the derivative of `modified_bessel_first_kind(0, x)` can be sped up dramatically with a simple change. Issue #3008 described how to do that for `von_mises_lpdf`, where the derivative is hand-coded, but that solution doesn't apply to custom models that call `modified_bessel_first_kind(0, x)` directly.

In the forward and reverse passes, the derivative of `modified_bessel_first_kind(v, x)` for a general order $v$ is calculated as:

$$ \frac{d I_v(x)}{dx} = I_{v-1}(x) - \frac{v}{x}I_v(x) $$

For $I_0(x)$, this results in the calculation:

$$ \frac{d I_0(x)}{dx} = I_{-1}(x) - \frac{0}{x}I_0(x) $$

Since $I_{-1}(x) = I_{1}(x)$ and the second term is 0, we have (see DLMF 10.29.3):

$$ \frac{d I_0(x)}{dx} = I_{1}(x) $$
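A quick numerical sanity check of this identity, as a sketch only: the hypothetical helper `bessel_i` below is a plain power-series stand-in for `modified_bessel_first_kind` with integer order (it is not Stan's or Boost's implementation), and `di_general` evaluates the general-order formula exactly as written above. For $v = 0$ it should agree with $I_1(x)$ and with a central finite difference of $I_0(x)$.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical stand-in for modified_bessel_first_kind(n, x), integer n:
// power series I_n(x) = sum_k (x/2)^(2k+n) / (k! (k+n)!), using I_{-n} = I_n.
double bessel_i(int n, double x) {
  n = std::abs(n);
  const double q = 0.25 * x * x;                              // (x/2)^2
  double term = std::pow(0.5 * x, n) / std::tgamma(n + 1.0);  // k = 0 term
  double sum = term;
  for (int k = 1; k < 500 && term > 1e-17 * sum; ++k) {
    term *= q / (k * static_cast<double>(k + n));  // ratio of successive terms
    sum += term;
  }
  return sum;
}

// General-order derivative as in the current fwd/rev code:
// dI_v/dx = I_{v-1}(x) - (v / x) I_v(x). Requires x > 0.
double di_general(int v, double x) {
  assert(x > 0.0);
  return bessel_i(v - 1, x) - v / x * bessel_i(v, x);
}

// Central finite difference of I_0 at x, used only for comparison.
double di0_finite_diff(double x, double h = 1e-5) {
  return (bessel_i(0, x + h) - bessel_i(0, x - h)) / (2.0 * h);
}
```

For example, `di_general(0, 2.0)`, `bessel_i(1, 2.0)` and `di0_finite_diff(2.0)` all come out at about 1.5906, i.e. $I_1(2)$.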

As described here, calculating `modified_bessel_first_kind(1, x)` is about 10 times faster than calculating `modified_bessel_first_kind(-1, x)`. Thus the formula above, while valid for any order, is very inefficient for models that use `modified_bessel_first_kind(0, x)`, which is the most common order (at least in my field). It needlessly computes $I_0(x)$ even though that term vanishes, and it computes $I_{-1}(x)$ instead of the equal but much cheaper $I_1(x)$.
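To make the cost argument concrete, the sketch below counts Bessel evaluations per gradient using a hypothetical instrumented stand-in `bessel_i` (a power series, not Stan's code, so it does not reproduce the 10x cost gap between orders; it only counts calls). In C++ an expression like `0.0 * f(x)` still calls `f`, so the general formula with $v = 0$ pays for two evaluations, while the special case pays for one.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Counts how many times bessel_i is called (illustration only).
int bessel_calls = 0;

// Hypothetical stand-in for modified_bessel_first_kind(n, x), integer n,
// via the power series for I_n(x) with I_{-n} = I_n.
double bessel_i(int n, double x) {
  ++bessel_calls;
  n = std::abs(n);
  const double q = 0.25 * x * x;
  double term = std::pow(0.5 * x, n) / std::tgamma(n + 1.0);
  double sum = term;
  for (int k = 1; k < 500 && term > 1e-17 * sum; ++k) {
    term *= q / (k * static_cast<double>(k + n));
    sum += term;
  }
  return sum;
}

// General-order gradient: evaluates I_{v-1}(x) AND I_v(x), even when v == 0,
// because multiplying by zero does not skip evaluating the other factor.
double grad_general(int v, double x) {
  return bessel_i(v - 1, x) - v / x * bessel_i(v, x);
}

// Order-0 special case: a single evaluation of I_1(x).
double grad_order0(double x) { return bessel_i(1, x); }
```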

Example

In a model I'm currently building, which has the likelihood

$$ f(\theta, c, \kappa) = \exp\bigg(\frac{c\ \exp(\kappa \cos\theta)}{2 \pi I_0(\kappa)}\bigg) \Big/ Z(c,\kappa) $$

after many other optimizations, over 90% of the likelihood's run time is now spent calculating $I_0(\kappa)$. This shows up when profiling with Stan's `profile` blocks and reading the results with the cmdstanr package (`fit$profiles()`):

```stan
functions {
  real sdm_lpdf(vector y, vector mu, vector kappa) {
    // declared outside the profile block so it stays in scope afterwards
    vector[num_elements(kappa)] be;
    profile("lpdf_be") {
      be = modified_bessel_first_kind(0, kappa);
    }
    // code for calculating the rest of the likelihood and returning it
  }
}

// other code

model {
  // other code
  profile("model_lpdf_total") {
    target += sdm_lpdf(Y | mu, kappa);
  }
  // other code
}
```

The profile output shows (times in seconds):

```
                     name thread_id total_time forward_time reverse_time chain_stack no_chain_stack autodiff_calls no_autodiff_calls
1        model_lpdf_total         1   1233.100   76.9433000  1.15616e+03  1462836268     1462164925          60918                 1
2                 lpdf_be         1   1201.490   52.2222000  1.14927e+03   731053200              0          60918                 1
```

The vast majority of that time is spent in the reverse autodiff pass (1149.3 s of the 1201.5 s spent in `lpdf_be`).
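As a worked illustration of why the gradient of this term matters, assume (as in the Stan code) that the concentration $\kappa$ (`kappa`) is the argument passed to $I_0$, and write $g(\kappa)$ for the term inside the outer exponential. Ignoring the normalizing constant $Z$, the chain rule gives:

$$ g(\kappa) = \frac{c\,\exp(\kappa\cos\theta)}{2\pi I_0(\kappa)}, \qquad \frac{dg}{d\kappa} = g(\kappa)\left(\cos\theta - \frac{I_0'(\kappa)}{I_0(\kappa)}\right) = g(\kappa)\left(\cos\theta - \frac{I_1(\kappa)}{I_0(\kappa)}\right) $$

So every reverse pass needs $I_0'(\kappa)$, which the general-order code evaluates as $I_{-1}(\kappa) - \frac{0}{\kappa}I_0(\kappa)$ rather than directly as $I_1(\kappa)$.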

Requested change

I envision two possibilities:

  1. Add a conditional statement to the `fwd` and `rev` passes that handles the derivative of the special case `modified_bessel_first_kind(0, x)`.

  2. Replace the derivative formula with

$$ \frac{d I_v(x)}{dx} = I_{v+1}(x) + \frac{v}{x}I_v(x) $$

which is equivalent to the current formula (see DLMF 10.29.2) but avoids the inefficient calculation for negative orders. The downside of this option is twofold. First, it still calculates an extra Bessel function, even though for $v = 0$ that term is multiplied by 0 (is this correct? I'm not sure how autodiff handles such cases). Second, it will instead make the derivative less efficient for models that use $I_1(x)$.
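A minimal sketch comparing the current formula with both options, assuming integer orders and $x > 0$, and again using a hypothetical power-series stand-in `bessel_i` rather than Stan's `modified_bessel_first_kind`. All three should agree numerically; they differ only in which Bessel functions they evaluate. Option 2 agrees with the current formula because of the recurrence $I_{v-1}(x) - I_{v+1}(x) = \frac{2v}{x} I_v(x)$ (DLMF 10.29.1).

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical stand-in for modified_bessel_first_kind(n, x), integer n.
double bessel_i(int n, double x) {
  n = std::abs(n);  // I_{-n}(x) = I_n(x)
  const double q = 0.25 * x * x;
  double term = std::pow(0.5 * x, n) / std::tgamma(n + 1.0);
  double sum = term;
  for (int k = 1; k < 500 && term > 1e-17 * sum; ++k) {
    term *= q / (k * static_cast<double>(k + n));
    sum += term;
  }
  return sum;
}

// Current formula: I_{v-1}(x) - (v / x) I_v(x).
double grad_current(int v, double x) {
  assert(x > 0.0);
  return bessel_i(v - 1, x) - v / x * bessel_i(v, x);
}

// Option 1: special-case v == 0, otherwise keep the current formula.
double grad_option1(int v, double x) {
  if (v == 0) return bessel_i(1, x);
  return grad_current(v, x);
}

// Option 2: I_{v+1}(x) + (v / x) I_v(x).
double grad_option2(int v, double x) {
  assert(x > 0.0);
  return bessel_i(v + 1, x) + v / x * bessel_i(v, x);
}

// Relative agreement check.
bool rel_close(double a, double b) {
  return std::fabs(a - b) <= 1e-12 * std::fabs(b);
}
```

Note that for $v = 1$, option 2 needs $I_2(x)$ where the current formula needs $I_0(x)$, which is the trade-off described above.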

When I reran the model after manually patching my Stan installation with option 1, estimation was about 4 times faster (from 11 h down to 3 h!).

Expected Output

For option 1, a possible implementation is to change the following in `rev`:

```cpp
bvi_->adj_
    += adj_
       * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
          + modified_bessel_first_kind(ad_ - 1, bvi_->val_));
```

to the following. I'm not sure what the most efficient way to code the branch is; is a simple `if`/`else` enough?

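```cpp
if (ad_ == 0) {
  // Special case: d/dx I_0(x) = I_1(x), so only one Bessel evaluation is needed
  bvi_->adj_ += adj_ * modified_bessel_first_kind(1, bvi_->val_);
} else {
  // General order: d/dx I_v(x) = I_{v-1}(x) - (v / x) I_v(x)
  bvi_->adj_
      += adj_
         * (-ad_ * modified_bessel_first_kind(ad_, bvi_->val_) / bvi_->val_
            + modified_bessel_first_kind(ad_ - 1, bvi_->val_));
}
```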
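To check the branching logic outside Stan, here is a toy reverse-mode node. `Node`, `BesselNode` and `bessel_i` are hypothetical, simplified stand-ins for Stan's `vari` classes and `modified_bessel_first_kind`, not the real Stan Math API. One deliberate variation from the snippet above: in the general branch the node's own `val_` already holds $I_v(x)$, so the sketch reuses it instead of recomputing `modified_bessel_first_kind(ad_, bvi_->val_)`.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical stand-in for modified_bessel_first_kind(n, x), integer n.
double bessel_i(int n, double x) {
  n = std::abs(n);  // I_{-n}(x) = I_n(x)
  const double q = 0.25 * x * x;
  double term = std::pow(0.5 * x, n) / std::tgamma(n + 1.0);
  double sum = term;
  for (int k = 1; k < 500 && term > 1e-17 * sum; ++k) {
    term *= q / (k * static_cast<double>(k + n));
    sum += term;
  }
  return sum;
}

// Hypothetical, simplified reverse-mode node: a value and an adjoint.
struct Node {
  double val_;
  double adj_ = 0.0;
  explicit Node(double v) : val_(v) {}
};

// Toy node for y = I_v(x); ad_ is the order v, bvi_ points to the operand x.
struct BesselNode : Node {
  int ad_;
  Node* bvi_;
  BesselNode(int v, Node* x) : Node(bessel_i(v, x->val_)), ad_(v), bvi_(x) {}

  // Propagate this node's adjoint to x.
  void chain() {
    if (ad_ == 0) {
      // d/dx I_0(x) = I_1(x): one Bessel evaluation.
      bvi_->adj_ += adj_ * bessel_i(1, bvi_->val_);
    } else {
      // d/dx I_v(x) = I_{v-1}(x) - (v / x) I_v(x); val_ already stores I_v(x).
      bvi_->adj_ += adj_ * (-ad_ * val_ / bvi_->val_
                            + bessel_i(ad_ - 1, bvi_->val_));
    }
  }
};
```

Usage: construct `Node x(2.0); BesselNode y(0, &x);`, seed `y.adj_ = 1.0;`, call `y.chain();`, and `x.adj_` then holds $I_1(2)$.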

And make the corresponding change in `fwd`:

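```cpp
// modified_bessel_first_kind_z holds I_v(z.val_), computed earlier in the function
if (v == 0) {
  // Special case: d/dz I_0(z) = I_1(z); the value is still I_0(z)
  return fvar<T>(modified_bessel_first_kind_z,
                 z.d_ * modified_bessel_first_kind(1, z.val_));
}
return fvar<T>(modified_bessel_first_kind_z,
               -v * z.d_ * modified_bessel_first_kind_z / z.val_
                   + z.d_ * modified_bessel_first_kind(v - 1, z.val_));
```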
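Likewise for forward mode, a sketch with a hypothetical dual-number type `Dual` (standing in for Stan's `fvar<double>`; its `val_` and `d_` members only mirror the names above) and the same kind of power-series stand-in `bessel_i`. It illustrates that the result carries both a value and a tangent, so the $v = 0$ branch must still return $I_0(z)$ as the value and put $z' \, I_1(z)$ in the tangent.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>

// Hypothetical stand-in for modified_bessel_first_kind(n, x), integer n.
double bessel_i(int n, double x) {
  n = std::abs(n);  // I_{-n}(x) = I_n(x)
  const double q = 0.25 * x * x;
  double term = std::pow(0.5 * x, n) / std::tgamma(n + 1.0);
  double sum = term;
  for (int k = 1; k < 500 && term > 1e-17 * sum; ++k) {
    term *= q / (k * static_cast<double>(k + n));
    sum += term;
  }
  return sum;
}

// Hypothetical dual number: value plus tangent (directional derivative).
struct Dual {
  double val_;
  double d_;
};

// Forward-mode I_v(z) for integer v, mirroring the proposed fwd branching.
Dual bessel_i_fwd(int v, const Dual& z) {
  const double i_v = bessel_i(v, z.val_);
  if (v == 0) {
    // d/dz I_0(z) = I_1(z)
    return Dual{i_v, z.d_ * bessel_i(1, z.val_)};
  }
  // d/dz I_v(z) = I_{v-1}(z) - (v / z) I_v(z)
  return Dual{i_v, -v * z.d_ * i_v / z.val_ + z.d_ * bessel_i(v - 1, z.val_)};
}
```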

Current Version:

v4.8.0

@venpopov
Contributor Author

@andrjohns I can try to implement this after our discussion in the other issue. Do you think the conditional approach, checking whether the order of the Bessel function is 0, is appropriate?
