You say:

> It is noteworthy that the parallel version will appear to be much slower due to a slow compilation in JAX. This could be improved by using a different implementation of the associative scan or by fixing the number of levels the way it is done in TensorFlow Probability.

What do you mean by 'fixing the number of levels'?
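The compile-time effect mentioned above can be measured roughly as follows. This is a minimal sketch, assuming a recent JAX with the ahead-of-time `jax.jit(f).lower(x).compile()` API; the cumulative-sum functions are hypothetical stand-ins for a real filter, chosen only because they make the sequential (`jax.lax.scan`) and parallel (`jax.lax.associative_scan`) versions easy to compare.

```python
import time

import jax
import jax.numpy as jnp


def sequential_cumsum(x):
    # lax.scan carries the running sum through a single compiled loop.
    def step(carry, xi):
        carry = carry + xi
        return carry, carry

    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x)
    return ys


def parallel_cumsum(x):
    # associative_scan is traced out level by level into a log-depth tree.
    return jax.lax.associative_scan(jnp.add, x)


def compile_seconds(fn, x):
    """Wall-clock time to trace, lower and compile fn for the shape of x."""
    start = time.perf_counter()
    jax.jit(fn).lower(x).compile()
    return time.perf_counter() - start


x = jnp.arange(1.0, 10_001.0)
print("scan:            ", compile_seconds(sequential_cumsum, x))
print("associative_scan:", compile_seconds(parallel_cumsum, x))
```

Absolute numbers depend on the machine and JAX version, so only the relative gap between the two lines is of interest.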
TBH there is not much that can be done to improve the CPU speed, since a Blelloch scan requires roughly three times the serial work of a simple sequential scan. It's also worth noting that the parallel Kalman filter/smoother (KF/KS) will be slower even without this overhead, because it requires "inverting" matrices the size of the latent space, which is often bigger than the observation space (the standard sequential filter only inverts observation-sized matrices).
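To show where those latent-sized inversions come from, here is a sketch of the pairwise combination rule for a parallel Kalman filter, as we understand it from Särkkä and García-Fernández's temporal parallelization of Bayesian smoothers. Assumptions: each element is a tuple `(A, b, C, eta, J)` holding a linear-Gaussian transition and an information-form likelihood term; building elements from a concrete model is omitted; `filtering_combine`, `parallel_filter_elements` and `random_element` are hypothetical names, and explicit inverses are replaced by `jnp.linalg.solve`.

```python
import jax
import jax.numpy as jnp


def filtering_combine(elem_i, elem_j):
    """Combine filtering elements (A, b, C, eta, J); elem_i comes first in time.

    Every linear solve is against an n x n matrix, n = latent (state) dimension.
    """
    A_i, b_i, C_i, eta_i, J_i = elem_i
    A_j, b_j, C_j, eta_j, J_j = elem_j
    eye = jnp.eye(A_i.shape[-1])

    M = eye + C_i @ J_j  # first latent-sized "inversion"
    A_ij = A_j @ jnp.linalg.solve(M, A_i)
    b_ij = A_j @ jnp.linalg.solve(M, b_i + C_i @ eta_j) + b_j
    C_ij = A_j @ jnp.linalg.solve(M, C_i) @ A_j.T + C_j

    N = eye + J_j @ C_i  # second latent-sized "inversion"
    eta_ij = A_i.T @ jnp.linalg.solve(N, eta_j - J_j @ b_i) + eta_i
    J_ij = A_i.T @ jnp.linalg.solve(N, J_j) @ A_i + J_i
    return A_ij, b_ij, C_ij, eta_ij, J_ij


def parallel_filter_elements(elems):
    """All prefix combinations of a stacked time series of elements."""
    # associative_scan passes whole batches of elements, so vmap the pairwise rule.
    return jax.lax.associative_scan(jax.vmap(filtering_combine), elems)


def random_element(key, n=3):
    """Hypothetical well-conditioned random element, only for trying the rule out."""
    kA, kb, kC, ke, kJ = jax.random.split(key, 5)
    A = 0.5 * jax.random.normal(kA, (n, n))
    b = jax.random.normal(kb, (n,))
    L = 0.5 * jax.random.normal(kC, (n, n))
    C = L @ L.T + 0.1 * jnp.eye(n)  # symmetric positive definite covariance
    eta = jax.random.normal(ke, (n,))
    K = 0.5 * jax.random.normal(kJ, (n, n))
    J = K @ K.T  # symmetric positive semidefinite precision
    return A, b, C, eta, J


keys = jax.random.split(jax.random.PRNGKey(0), 6)
elems = [random_element(k) for k in keys]
# Turn the list of tuples into one tuple of arrays with a leading time axis.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *elems)
prefixes = parallel_filter_elements(stacked)
```

The sequential filter's gain only needs a solve against the observation-sized innovation covariance, whereas every application of this operator pays for two latent-sized solves, which is the extra cost described above.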
Looking back, I think the comparison to TensorFlow Probability was a small mistake on my part: TFP has such a utility to allow faster compilation for arrays of varying length, which is impossible in JAX since array shapes must be known at compile time.
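To make "number of levels" concrete: a recursive associative scan halves the problem at each level, so for n elements its recursion depth is floor(log2 n). The sketch below is a hypothetical hand-written odd/even variant (not JAX's own `associative_scan` source) that also reports this depth. Because the recursion is plain Python, it is unrolled while tracing, so every distinct static length n yields a different program; fixing the number of levels in advance, as TFP does, is what lets one graph serve inputs of varying length.

```python
import jax.numpy as jnp


def interleave(even, odd):
    """Merge even-indexed and odd-indexed results back into one array."""
    k = odd.shape[0]
    pairs = jnp.stack([even[:k], odd], axis=1).reshape((2 * k,) + odd.shape[1:])
    return jnp.concatenate([pairs, even[k:]])


def scan_with_levels(fn, x, depth=0):
    """Inclusive scan of x along axis 0; also returns the recursion depth."""
    n = x.shape[0]
    if n < 2:
        return x, depth
    # One level: combine adjacent pairs (x0, x1), (x2, x3), ...
    reduced = fn(x[0:-1:2], x[1::2])
    # odd[k] is the prefix ending at index 2k + 1.
    odd, levels = scan_with_levels(fn, reduced, depth + 1)
    # The prefix at even index 2k (k >= 1) is odd[k - 1] combined with x[2k].
    tail = fn(odd[: (n - 1) // 2], x[2::2])
    even = jnp.concatenate([x[:1], tail])
    return interleave(even, odd), levels


for n in (8, 9, 1000, 1024):
    _, levels = scan_with_levels(jnp.add, jnp.ones(n))
    print(n, levels)  # levels == n.bit_length() - 1, i.e. floor(log2 n)
```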
A "real solution" would, however, be to lower the associative_scan operation to XLA directly (as is done for other control-flow operations such as scan and while_loop), which would bypass most of the compilation cost. That would take a lot of human effort, though.
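One rough way to see why direct lowering matters is to compare the size of the programs JAX hands to XLA. This sketch assumes a recent JAX where `jax.jit(f).lower(x).as_text()` returns the lowered module as text; `scan_cumsum`, `assoc_cumsum` and `lowered_lines` are hypothetical helpers, and exact line counts will vary between JAX versions, so only the trend is meaningful.

```python
import jax
import jax.numpy as jnp


def scan_cumsum(x):
    # Lowered as a single loop whose body does not depend on len(x).
    def step(carry, xi):
        return carry + xi, carry + xi

    return jax.lax.scan(step, jnp.zeros((), x.dtype), x)[1]


def assoc_cumsum(x):
    # Unrolled at trace time into one block of ops per level.
    return jax.lax.associative_scan(jnp.add, x)


def lowered_lines(fn, n):
    """Number of lines in the lowered (pre-optimization) module for length n."""
    text = jax.jit(fn).lower(jnp.zeros(n)).as_text()
    return len(text.splitlines())


for n in (16, 1024, 65536):
    print(n, lowered_lines(scan_cumsum, n), lowered_lines(assoc_cumsum, n))
```

The scan column should stay essentially flat while the associative_scan column grows with log2 n; a dedicated XLA primitive would, in principle, keep the parallel version's program as compact as the scan's.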