fsdp=16 model=16 gbs=16 should work on 256 chips #773

Open
samos123 opened this issue Oct 22, 2024 · 9 comments

Comments

@samos123
Contributor

samos123 commented Oct 22, 2024

fsdp=16 model=16 global_batch_size=16 should work on 256 chips

The use case is being able to use a global batch size smaller than the total number of JAX processes.

This is supported in maxtext by using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291
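(The linked MaxText lines aren't reproduced here; the sketch below only illustrates the general pattern I believe they rely on: load a physical batch that satisfies the per-host divisibility constraints, then slice back down to the logical global batch size before training. Names are illustrative, not taken from MaxText or axlearn.)

    import jax.numpy as jnp

    def to_logical_batch(global_physical_batch: jnp.ndarray, global_logical_batch_size: int) -> jnp.ndarray:
        # Keep only the leading `global_logical_batch_size` examples; the rest is padding
        # that was only loaded to satisfy the per-host divisibility requirement.
        return global_physical_batch[:global_logical_batch_size, ...]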

Trying to get the 405B model running on v6e-256 (fsdp=16, model=16), but hitting this error:

I1022 20:32:33.715831 139189201369088 trainer.py:323] gpt_trainer process  19 step       -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/trainer.py", line 244, in __init__
    self._add_child("input", cfg.input.set(is_training=True))
  File "/root/axlearn/common/module.py", line 760, in _add_child
    module = child_config.instantiate(parent=self, **kwargs)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
    self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
  File "/root/axlearn/common/config.py", line 801, in instantiate
    return self.fn(*args, **kwargs)
  File "/root/axlearn/common/input_tf_data.py", line 799, in batch
    raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
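For context, the arithmetic behind this check (assuming 4 chips per v6e host, so 64 JAX processes on v6e-256; the snippet is illustrative only):

    num_chips = 256
    chips_per_host = 4                              # v6e host topology (assumed)
    num_data_feeds = num_chips // chips_per_host    # 64 JAX processes / data feeds
    global_batch_size = 16
    print(global_batch_size % num_data_feeds)       # 16, not 0, so input_tf_data.batch raises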
@samos123 samos123 changed the title fsdp=16, model=16 and gbs=16 should work on 256 chips fsdp=16 model=16 gbs=16 should work on 256 chips Oct 22, 2024
@markblee
Contributor

Hi @samos123, you can use the input dispatcher:

# If not None, creates an InputDispatcher and uses it for dispatching per-feed batches to
# global batches.
input_dispatcher: Optional[InputDispatcher] = None

class InputDispatcher(Module):
    """A Module to dispatch per-feed logical input batches to global logical batches on device.

    The dispatch process consists of three steps:
    - Convert each logical feed batch to a physical feed batch (logical_to_physical_batch);
    - Assemble a global physical batch from per-feed batches (utils.host_to_global_device_array);
    - Convert a global physical batch to a global logical batch (physical_to_logical_batch).

    This process is needed because utils.host_to_global_device_array requires that global batch
    size be divisible by number of devices.

    One should set up the local input generator to read the logical shard as specified by
    `feed_read_config` and batch the examples by `feed_logical_batch_size`.

    One should then call `logical_to_physical_batch` on each per-feed batch, followed by
    `utils.host_to_global_device_array` to generate the input array for pjit, then finally
    `physical_to_logical_batch` inside pjit.
    """

Some hosts will produce padding feeds which will be dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.
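Based on that docstring, the per-feed flow looks roughly like this (a sketch only; `dispatcher` and `logical_feed_batch` are placeholder names, and the exact call sites are assumptions rather than code copied from axlearn):

    # Per host: convert the logical feed batch into a physical feed batch (may add padding).
    physical_feed_batch = dispatcher.logical_to_physical_batch(logical_feed_batch)
    # Assemble the global physical batch across all hosts/feeds.
    global_physical_batch = utils.host_to_global_device_array(physical_feed_batch)
    # Inside pjit: recover the global logical batch (padding feeds are dropped here).
    global_logical_batch = dispatcher.physical_to_logical_batch(global_physical_batch)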

@samos123
Contributor Author

samos123 commented Oct 23, 2024

Maybe I'm misunderstanding the code... but input_dispatcher is None for the fuji models, so wouldn't it default to InputDispatcher already? @markblee

@samos123
Contributor Author

Are you saying I should create a custom InputDispatcher and pass that instead? That may make sense. Looking into that further.

@samos123
Contributor Author

samos123 commented Oct 23, 2024

Would these be the right settings for fsdp=16, model=16, gbs=16 on v6e-256?

        # Usually left unset. Defaults to
        # max(feed_logical_batch_size * num_physical_feeds, jax.device_count()).
        global_physical_batch_size = 16
  
        # The total number of physical feeds across all hosts. Defaults to jax.process_count().
        num_physical_feeds = 64

        # The local physical feed index. Must be in [0, num_physical_feeds).
        # Defaults to jax.process_index().
        physical_feed_index: Optional[int] = None

@samos123
Contributor Author

Currently in Fuji models this is set:

        cfg.input = input_tf_data.Input.default_config().set(
            is_training=True,
            source=train_input_source,
            processor=config_for_function(input_tf_data.identity),
            batcher=config_for_function(input_tf_data.batch).set(
                global_batch_size=train_batch_size,
                prefetch_buffer_size=tf.data.AUTOTUNE,
                pad_example_fn=input_tf_data.default_pad_example_fn,
            ),
        )

This is what's inside the input_tf_data.batch function:

    num_data_feeds = jax.process_count()
    if global_batch_size % num_data_feeds != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"number of JAX processes (data feeds) ({num_data_feeds})."
        )

So I suspect I need to modify the batch function directly.

@kelvin-zou
Contributor

@samos123 please read through the input logic; there is a logical batch and a physical batch. Once you understand those two key concepts, you should be good to go.

@hanzhi713
Member

hanzhi713 commented Oct 23, 2024

@samos123

global_logical_batch_size = 16
global_physical_batch_size = 256 # You can also try 64 here to see if it works.
logical_feed_indices = list(range(0, 64, 4)) # 1 in 4 hosts reads a single "real" batch.

You don't need to specify other fields because the batcher can infer them. Also, as Kelvin said, be sure to understand the meaning of these fields.
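For reference, wiring those values into the trainer input config would look something like this (a sketch under the assumption that the `input_dispatcher` field takes an `InputDispatcher` config with these field names; not verified against the current axlearn API):

    cfg.input.input_dispatcher = InputDispatcher.default_config().set(
        global_logical_batch_size=16,
        global_physical_batch_size=256,
        logical_feed_indices=list(range(0, 64, 4)),
    )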

@samos123
Contributor Author

Just sharing for now since it's related. I also hit this error when trying fsdp=16, model=16 and gbs=128 on 256 chips:

Stack Summary (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 330, in run
    raise
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer
    output = trainer.run(prng_key)
  Wrapped call axlearn.common.trainer.SpmdTrainer.run(jaxlib.xla_extension.ArrayImpl)
  File "/root/axlearn/common/trainer.py", line 501, in run
    utils.host_to_global_device_array(input_batch),
  File "/root/axlearn/common/utils.py", line 653, in host_to_global_device_array
    device_arrays = jax.tree.map(put_to_devices, host_arrays)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
  File "/root/axlearn/common/utils.py", line 637, in put_to_devices_fully_partitioned
    raise ValueError(f"({x.shape}) cannot be sharded across {len_local_devices} devices.")
ValueError: ((2, 8192)) cannot be sharded across 4 devices.
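The arithmetic behind this one (again assuming 4 chips per v6e host): gbs=128 spread over 64 data feeds gives 2 examples per host, and host_to_global_device_array then tries to split those 2 examples across the host's 4 local devices:

    per_feed_batch = 128 // 64              # 2 examples per host, i.e. the (2, 8192) array in the error
    local_devices = 4                       # chips per v6e host (assumed)
    print(per_feed_batch % local_devices)   # 2, not 0 -> "cannot be sharded across 4 devices"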

@kelvin-zou
Contributor

@samos123 please understand the physical batch logic: you shouldn't do global bs=128 over 256 chips, it is always 256. The logical batch is what you care more about, and the physical batch could be auto-configured from the logical one, but we don't have that in place yet.

fwiw, you also cannot do model=16, since the KV num_heads in the GQA model is 8. But that's a separate issue and you haven't hit it yet.
