
Custom PyDataset does not work with JAX when GPU and multiprocessing are enabled #20619

Closed
edge7 opened this issue Dec 10, 2024 · 5 comments
edge7 commented Dec 10, 2024

Hello,
The code below does not work; training gets stuck:

import os
import multiprocessing
import jax

os.environ["KERAS_BACKEND"] = "jax"
import keras


class MyDataset(keras.utils.PyDataset):
    def __init__(self):
        super().__init__(max_queue_size=1500, use_multiprocessing=True, workers=7)

    def __getitem__(self, item):
        print(item)
        return keras.ops.zeros((1, 1)), keras.ops.zeros((1, 10))

    @property
    def num_batches(self):
        return 5


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        return keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))


if __name__ == "__main__":
    # the below line won't fix the issue
    # multiprocessing.set_start_method("spawn")
    dataset = MyDataset()
    model = MyModel()

    model.compile(optimizer="adam")
    model.fit(dataset, epochs=100)

It works if I don't use the GPU or if I set use_multiprocessing=False.
Calling multiprocessing.set_start_method("spawn") does not help either.

This seems to work instead:

import os
import multiprocessing
import jax

os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np


class MyDataset(keras.utils.PyDataset):
    def __init__(self):
        super().__init__(max_queue_size=1500, use_multiprocessing=True, workers=7)

    def __getitem__(self, item):
        with jax.default_device(jax.devices("cpu")[0]):
            print(item)
            x = keras.ops.zeros((1, 1)), keras.ops.zeros((1, 10))
        return x

    @property
    def num_batches(self):
        return 5


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        return keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))


if __name__ == "__main__":
    # the below line won't fix the issue
    # multiprocessing.set_start_method("spawn")
    dataset = MyDataset()
    model = MyModel()

    model.compile(optimizer="adam")
    model.fit(dataset, epochs=100)

Please note the key difference is the context manager:

with jax.default_device(jax.devices("cpu")[0]):

See here for more details.

Is the above context manager the only way to use multi-processing in JAX with GPU?
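
For what it's worth, a minimal sketch of another possible workaround (an assumption on my part, not something tested in this thread): return plain NumPy arrays from __getitem__ so the workers never create JAX arrays and therefore should never need to initialize the CUDA backend.

import os

os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import keras


class NumpyDataset(keras.utils.PyDataset):
    # Hedged sketch: batches are built as NumPy arrays in the workers;
    # Keras converts them to backend tensors in the main process, so the
    # workers never touch the GPU.
    def __init__(self):
        super().__init__(max_queue_size=1500, use_multiprocessing=True, workers=7)

    def __getitem__(self, item):
        return np.zeros((1, 1), "float32"), np.zeros((1, 10), "float32")

    @property
    def num_batches(self):
        return 5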

@sonali-kumari1

Hi @edge7,

Thanks for reporting this issue.
Here is a breakdown of the errors/warnings you are facing:

multiprocessing.set_start_method("spawn") not working:
The context is already set to fork by default, which is why spawn is not working.

Training gets stuck:
Training gets stuck because os.fork() is called to create the worker processes, and fork() is incompatible with multithreaded code; since JAX is multithreaded, the forked workers deadlock. You can refer to the PyDataset class for more information. Setting use_multiprocessing=False will resolve the issue.
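
A minimal illustration of both points (a sketch, assuming Linux, where "fork" is the default start method):

import multiprocessing

if __name__ == "__main__":
    # On Linux the default start method is "fork"; fork() copies only the
    # calling thread, so locks held by JAX's other threads stay locked
    # forever in the child, which is the deadlock you are seeing.
    print(multiprocessing.get_start_method())  # "fork"
    # set_start_method raises RuntimeError if the method is already fixed;
    # force=True overrides it. Either way it must run before any workers
    # are created.
    multiprocessing.set_start_method("spawn", force=True)
    print(multiprocessing.get_start_method())  # "spawn"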

Attaching a gist for your reference. You can also refer to this issue.


edge7 commented Dec 11, 2024

Hi @sonali-kumari1, thanks for the answer.
I have a note about:

"The context is already set to fork by default, which is why spawn is not working."
I don't think that is happening in the code above: multiprocessing.set_start_method raises a RuntimeError if the start method has already been set, and no exception is raised here.

If I run this:

import os
import multiprocessing
os.environ["KERAS_BACKEND"] = "jax"
import keras


class MyDataset(keras.utils.PyDataset):
    def __init__(self):
        super().__init__(max_queue_size=1500, use_multiprocessing=True, workers=7)

    def __getitem__(self, item):
        print(item)
        return keras.ops.zeros((1, 1)), keras.ops.zeros((1, 10))

    @property
    def num_batches(self):
        return 5


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        return keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))


if __name__ == "__main__":
    # the below line won't fix the issue
    multiprocessing.set_start_method("spawn")
    dataset = MyDataset()
    model = MyModel()

    model.compile(optimizer="adam")
    model.fit(dataset, epochs=100)

and then I set a breakpoint in /usr/lib/python3/multiprocessing/context.py, specifically here:

[screenshots: debugger stopped at the breakpoint in multiprocessing/context.py, showing the start method being set to "spawn"]

So I can assume spawn is properly set.
I do believe you are setting it in the wrong place: it needs to be set inside the if __name__ == "__main__": block.
Even so, I still get:

RuntimeError: Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
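
One hedged idea that the error message itself suggests (not verified in this thread): with "spawn", each worker re-imports the main module, so the workers could be routed onto the CPU backend before keras/jax are imported, leaving CUDA to the parent only.

import multiprocessing
import os

if multiprocessing.current_process().name != "MainProcess":
    # Only spawned worker processes take this branch; the parent keeps CUDA.
    # JAX_PLATFORMS=cpu is the exact knob the RuntimeError above points at.
    os.environ["JAX_PLATFORMS"] = "cpu"

os.environ["KERAS_BACKEND"] = "jax"
import keras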

By the way, for CPU-bound transformations with JAX (and GPU), is it safe to say that tf.data should be preferred over PyDataset?

@sonali-kumari1

Hi @edge7,
You are right: multiprocessing.set_start_method("spawn") inside the __main__ block will set the context properly to spawn, and setting os.environ["CUDA_VISIBLE_DEVICES"] = "0" together with use_multiprocessing=False will resolve the runtime error. Attaching a gist for your reference.
PyDataset is a better way to do multiprocessing, and you can find more details about JAX parallelism here.
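
A compact sketch of that suggested configuration (hedged; the referenced gist is not reproduced here):

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose a single GPU
os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import keras


class ThreadedDataset(keras.utils.PyDataset):
    def __init__(self):
        # use_multiprocessing=False makes the 7 workers threads, which
        # coexist with JAX's internal threads without fork() hazards.
        super().__init__(max_queue_size=1500, use_multiprocessing=False, workers=7)

    def __getitem__(self, item):
        return np.zeros((1, 1), "float32"), np.zeros((1, 10), "float32")

    @property
    def num_batches(self):
        return 5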


edge7 commented Dec 12, 2024

I noticed some inconsistencies above, which is a bit frustrating.
@james77777778 may I ask for your informed opinion on the following?

Suppose one has CPU-bound pre-processing operations and therefore would like to use multiprocessing. Is the following a viable option?

import os
import multiprocessing

os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras


class MyDataset(keras.utils.PyDataset):
    def __init__(self):
        super().__init__(max_queue_size=1500, use_multiprocessing=True, workers=7)

    def __getitem__(self, item):
        print(item)
        with jax.default_device(jax.devices("cpu")[0]):
            return keras.ops.zeros((1, 1)), keras.ops.zeros((1, 10))

    @property
    def num_batches(self):
        return 5


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        return keras.ops.mean(keras.losses.mean_squared_error(y, y_pred))


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    print(multiprocessing.get_start_method())
    dataset = MyDataset()
    model = MyModel()

    model.compile(optimizer="adam")
    model.fit(dataset, epochs=100)

I am asking because training works with the JAX context manager, but I am unsure about side effects, even though I am fairly sure the model is still trained on the GPU and only __getitem__ is placed on the CPU.
Or maybe for situations like this (JAX and CPU-bound pre-processing on the fly) it's just better to use tf.data?
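
For reference, a minimal sketch of the tf.data alternative (standard tf.data usage, not code from this thread): the CPU-bound step runs in parallel threads via num_parallel_calls, so no processes are forked and JAX never runs inside a worker.

import numpy as np
import tensorflow as tf

x = np.zeros((5, 1), np.float32)
y = np.zeros((5, 10), np.float32)


def preprocess(xi, yi):
    # stand-in for a CPU-bound transformation
    return xi * 2.0, yi


dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)
)
# model.fit(dataset, epochs=100)  # Keras 3 accepts tf.data datasets directly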

@james77777778

I use tf.data for its performance and torch.utils.data.DataLoader for its flexibility. This is just my personal preference; I'm chiming in because @edge7 pinged me.

I'm not familiar with the PyDataset workflow.

edge7 closed this as completed Dec 13, 2024