parallel-kalman-jax cpu associative scan is very slow #9

Open
murphyk opened this issue Mar 22, 2022 · 1 comment

murphyk commented Mar 22, 2022

You say:

"It is noteworthy that the parallel version will appear to be much slower due to a slow compilation in JAX. This could be improved by using a different implementation of the associative scan or by fixing the number of levels the way it is done in TensorFlow Probability."

What do you mean by 'fixing the number of levels'?

[Screenshot attached: Screen Shot 2022-03-21 at 8 04 54 PM]

@AdrienCorenflos
Collaborator

TBH there is not much that can be done to improve the CPU speed, as the Blelloch scan requires roughly 3 times the amount of serial work that a simple scan would require. It's also worth noting that the parallel KF/KS would be slower even without this overhead, as it requires "inverting" matrices the size of the latent space, which is (often) bigger than the observation space.
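To make the comparison concrete, here is a minimal sketch (not the repository's code) that runs the same scalar linear recursion `x_t = a_t * x_{t-1} + b_t` once with `jax.lax.scan` and once with `jax.lax.associative_scan`. The names `combine`, `sequential`, and `parallel` and the scalar model are assumptions chosen for illustration; the real Kalman operators act on matrices and are more expensive per step.

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    # Compose two affine maps x -> a * x + b, applying `left` first.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def sequential(a, b):
    # Plain serial recursion, starting from x_0 = 0.
    def step(x, ab):
        a_t, b_t = ab
        x = a_t * x + b_t
        return x, x

    _, xs = jax.lax.scan(step, jnp.zeros(()), (a, b))
    return xs


def parallel(a, b):
    # Prefix composition of the affine maps; the offset term of the
    # t-th composed map equals x_t because x_0 = 0.
    _, xs = jax.lax.associative_scan(combine, (a, b))
    return xs


a = jnp.linspace(0.5, 0.9, 16)
b = jnp.ones(16)
xs_seq = sequential(a, b)
xs_par = parallel(a, b)
```

Both versions should agree up to floating-point error. To compare speed, one could wrap each function in `jax.jit`, call `.block_until_ready()` on the result, and time the first call (which includes compilation) separately from later calls.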

Looking back, I think the comparison to TF was a small mistake on my end: the reason TF has such a utility is to speed up compilation for arrays of varying length, which are not possible in JAX.

A "real solution", however, would be to lower the associative_scan operation to XLA directly (as is done for other control-flow operations such as scan and while_loop), so as to bypass most of the compilation work. This would cost a lot of human effort though.
