Custom PyDataset does not work with JAX when GPU and multiprocessing are enabled #20619
Comments
Hi @edge7, thanks for reporting this issue.
Training gets stuck: attaching a gist for your reference. You can also refer to this issue.
Hi @sonali-kumari1, thanks for the answer.
If I run this:
and then set a breakpoint in /usr/lib/python3/multiprocessing/context.py, specifically here: then I can assume "spawn" will be set properly.
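Rather than breakpointing inside context.py, the effective start method can be checked directly from the stdlib (a small sketch, not from the thread):

```python
import multiprocessing

# Confirm which start method is actually in effect; force=True
# allows overriding a method that was already set elsewhere,
# e.g. by a framework that initialized multiprocessing first.
multiprocessing.set_start_method("spawn", force=True)
print(multiprocessing.get_start_method())  # spawn
```
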
By the way, for CPU-bound transformations with JAX (and GPU), is it safe to say one should prefer a tf.data.Dataset over a PyDataset?
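For reference, a minimal sketch of the tf.data route (assuming TensorFlow is installed; the doubling map is a hypothetical stand-in for a CPU-bound transform): map steps run in tf.data's own thread pool, sidestepping Python multiprocessing entirely.

```python
import tensorflow as tf

# Hypothetical CPU-bound per-example transform expressed as a map;
# tf.data parallelizes it across threads, no worker processes needed.
ds = (
    tf.data.Dataset.range(8)
    .map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
print([int(x) for x in ds])  # [0, 2, 4, 6, 8, 10, 12, 14]
```

Iteration order is preserved because tf.data maps are deterministic by default.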
Hi @edge7,
I noticed some inconsistencies above, which is a bit frustrating. Suppose one has CPU-bound pre-processing operations, which means one would want multiprocessing; is the following a viable option:
I am asking because training works with the JAX context manager, but I am unsure about side effects, even though I am fairly sure the model is still trained on the GPU and only __getitem__ runs on the CPU.
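The pattern being asked about looks roughly like this (a sketch under my reading of the thread; cpu_preprocess is a hypothetical stand-in for the work done in __getitem__):

```python
import jax
import jax.numpy as jnp

def cpu_preprocess(batch):
    # Pin array creation to the CPU device so a data worker never
    # touches the GPU; arrays the model already placed on the GPU
    # are unaffected by this context manager.
    with jax.default_device(jax.devices("cpu")[0]):
        return jnp.asarray(batch, dtype=jnp.float32) / 255.0

print(cpu_preprocess([0, 255]))  # [0. 1.]
```
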
I use … I'm not familiar with …
Hello,
The code below does not work and training gets stuck:
It works if I don't use the GPU, or if I set multiprocessing=False. Enabling multiprocessing.set_start_method("spawn") does not help either. This seems to work instead:
Please note:
with jax.default_device(jax.devices("cpu")[0]):
See here for more details.
Is the above context manager the only way to use multiprocessing in JAX with a GPU?
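Not from the thread, but for completeness: a minimal stdlib sketch of the spawn pattern that usually has to hold for GPU-safe workers — the start method must be chosen before any worker exists, and module-level code must be guarded because spawned children re-import the main module.

```python
import multiprocessing as mp

def transform(x):
    # Hypothetical stand-in for a CPU-bound __getitem__.
    return x * x

if __name__ == "__main__":
    # get_context selects "spawn" locally without mutating the
    # global start method; spawn avoids forking a process that
    # already holds a GPU context.
    ctx = mp.get_context("spawn")
    with ctx.Pool(2) as pool:
        print(pool.map(transform, range(4)))  # [0, 1, 4, 9]
```
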