Outer JIT compilation time could be optimized #9
Comments
You can remove the jit in the gist. We have discussed this before. The downside is that the backprops of growth, modes, and lpt are not jitted in this approach.
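To make that tradeoff concrete, here is a minimal sketch (not pmwd's actual API; `model` is just a hypothetical stand-in for a growth/modes/lpt/nbody pipeline) of what is gained and lost by dropping the outer jit:

```python
import jax
import jax.numpy as jnp

# Minimal sketch of the tradeoff; `model` is a hypothetical stand-in for a
# pipeline like growth -> modes -> lpt -> nbody, not pmwd's actual API.
def model(theta):
    return jnp.sum(jnp.sin(theta) ** 2)

grad_eager = jax.grad(model)            # no outer jit: backprop runs op-by-op
grad_jitted = jax.jit(jax.grad(model))  # outer jit: compiled, but the first
                                        # call pays the full compilation cost

theta = jnp.ones(8)
print(grad_eager(theta))
print(grad_jitted(theta))
```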
The quickstart notebook has some timings.
Oh right right, well the problem is that ultimately there will need to be a jit outside of pmwd. For instance in an HMC sampler, or as part of a larger simulation model that also includes additional computational layers. And concretely for me it's good to have my distributed code inside a jit (not required though). Ok, I'll see if I can think of how to do the export with a scan. But I can't imagine it being that much slower? Can you remember by how much it changed things?
I don't see why HMC requires a top-level jit. Maybe it's more convenient that way, but why mandatory? And what do you have in mind for the larger model? If it's a really big model like almost any NN? As for how much the scan changed things, I remember it was 20-30%. And I think one cannot export inside jit.
I wonder if it's still slow if you jit
When you use an external sampler like numpyro or TFP, it will usually compile all the logic of the HMC kernel, including the evaluation of the log likelihood. It may be possible to disable that (and I agree it's not in principle necessary), but by default in JAX the user expects to be able to jit their code without knowledge of the underlying implementation. 20-30% sounds like a lot, yeah... I guess it was for small-size problems, but still, I see the reason for this tradeoff if it is that bad. And otherwise, yeah, I agree that saving snapshots to disk from within a jitted function wouldn't be super trivial. But if you are doing things on the fly that's probably not super important. I can see though that maybe you want to avoid the memory cost of storing intermediate snapshots...
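For context, a rough sketch of why the forward model tends to end up inside a jit anyway: samplers typically jit their whole transition kernel, which closes over the user's log-probability. The `log_prob` and `leapfrog_step` below are hypothetical stand-ins, not numpyro's or TFP's actual internals:

```python
import jax
import jax.numpy as jnp

# `log_prob` stands in for a likelihood that would call pmwd internally.
def log_prob(theta):
    return -0.5 * jnp.sum(theta ** 2)

@jax.jit  # the sampler owns this jit, not the user
def leapfrog_step(theta, momentum, step_size):
    # The user's model (and its gradient) is traced and compiled here,
    # whether or not the user jitted it themselves.
    momentum = momentum + 0.5 * step_size * jax.grad(log_prob)(theta)
    theta = theta + step_size * momentum
    momentum = momentum + 0.5 * step_size * jax.grad(log_prob)(theta)
    return theta, momentum
```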
Some of those users are likely already used to the compilation speed. I have heard complaints about JAX taking minutes to compile NNs. I don't know why JAX cannot get cache hits on
A discussion that mentions that adding lower-level jit helps with compilation time.
I think saving snapshots is also important for normal use cases, like generating mocks.
In this example:
https://gist.github.com/EiffL/8e46d261e5d52cd28ca81e233fef9b04
It takes 3 mins for the first evaluation of the model to run, but just a few seconds in the second run.
@modichirag has also been able to check that the compilation time grows with the number of steps. This would indicate that the code is building an overly complex computational graph that explicitly includes each step of the nbody.
I suspect this is due to the Python for loop in the nbody function. Things would probably improve a lot if it were replaced with a lax.scan (see the sketch below).
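A rough sketch of what that change might look like; `toy_step`, `state`, and `a_steps` are placeholders, not pmwd's actual interfaces:

```python
import jax
import jax.numpy as jnp

# `toy_step` is a placeholder for one N-body time step, not pmwd's API.
def toy_step(state, a):
    pos, vel = state
    vel = vel - a * pos   # stand-in "kick"
    pos = pos + a * vel   # stand-in "drift"
    return pos, vel

def nbody_unrolled(state, a_steps):
    # Python for loop: every step is traced into the XLA graph separately,
    # so compilation time grows with the number of steps.
    for a in a_steps:
        state = toy_step(state, a)
    return state

def nbody_scanned(state, a_steps):
    # lax.scan traces the body once and emits a single loop, so the
    # compilation time is roughly independent of the number of steps.
    def body(carry, a):
        return toy_step(carry, a), None
    state, _ = jax.lax.scan(body, state, a_steps)
    return state

state0 = (jnp.zeros(3), jnp.ones(3))
a_steps = jnp.linspace(0.01, 0.1, 100)
pos, vel = jax.jit(nbody_scanned)(state0, a_steps)
```

The 20-30% runtime cost mentioned above presumably comes from XLA being able to optimize across iterations in the unrolled graph but not inside the scanned loop.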