
Add GKE A3 Ultra support #940

Closed
wants to merge 21 commits into from

Conversation

samos123
Contributor

  • Adds support for easily adding more GKE GPU accelerators by using a base class
  • Adds Fuji v2 70B benchmark results

@samos123 samos123 requested review from ruomingp, markblee and a team as code owners January 22, 2025 00:23
@samos123 samos123 marked this pull request as draft January 22, 2025 00:23
Dockerfile Outdated
@@ -101,14 +93,63 @@ COPY . .
# GPU container spec. #
################################################################################

FROM base AS gpu
# This causes INTERNAL: No valid engine configs for Matmul error
Contributor Author

@markblee @kelvin-zou are you fine with moving to nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as the base image? I couldn't get A3 Ultra working on the original Python base image.

Dockerfile Outdated

# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install .[core,gpu]
COPY . .
RUN pip install -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38" \
Contributor Author

TODO: this should be removed. Maybe we should wait until axlearn main upgrades to 0.4.38.

Contributor

@ruomingp ruomingp left a comment

Thanks!

Dockerfile Outdated
# So we're using the NVIDIA provided cuda image instead which works.
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as gpu

# Copy from original base
Contributor

Does it mean that we will need to keep the following commands consistent with those in BASE? Does Dockerfile support functions?

Contributor Author

I did some research and sadly the answer is no. We could try to create a bash script that gets used in both instead of copy-pasting code. I do agree that we should figure out a way to reuse the same setup steps.

Dockerfile Outdated
Comment on lines 100 to 102
# Using `FROM base as GPU` causes INTERNAL: No valid engine configs for Matmul error.
# So we're using the NVIDIA provided cuda image instead which works.
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as gpu
Contributor

Let's make sure this change doesn't break our GPU training on AWS.

Contributor Author

I think it's fine but defer to @kelvin-zou

Contributor

Seconding what Ruoming recommended: maybe split it into a Dockerfile.gcp if there are some GCP-specific things, and import this Dockerfile as the base. The base image import is fine though.

Contributor Author

There isn't anything GCP-specific except the installation of gcloud, which also happens in the axlearn base image. Happy to move to Dockerfile.gcp though, since it uses a different base image. I think that makes sense.

Contributor Author

I actually think it's better to keep this as-is; we don't want two separate GPU images, one broken and one working. At some point, it may be helpful to switch to the JAX Stable Stack image so we can reuse the same validated base image on both GPU and TPU.

Are you fine with keeping it as-is, since there is nothing GCP-specific happening in the gpu image?

samos123 (Contributor Author) commented Feb 5, 2025

I still have access to internal capacity, so I wanted to get some reviews and make changes before I lose the ability to verify any changes.

@samos123 samos123 marked this pull request as ready for review February 5, 2025 00:08
Dockerfile Outdated

# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install .[core,gpu]
COPY . .

# TODO(samos123): remove this once axlearn upgrades to Jax 0.4.38.
RUN pip install -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38"
Contributor

can we rely on axlearn's jax version instead?

Contributor Author

Yes, I will remove this before merging since I think 0.4.37 works as well.

samos123 (Contributor Author) commented Mar 3, 2025

I was able to get it working with the original Docker image for GPU. I think the newer Jax and NVIDIA drivers may have removed the need to use the CUDA image from NVIDIA.

@markblee @kelvin-zou could you please review again and get this merged? We have people ready to start adding A4 support and benchmarking it; however, that work builds on top of this PR, so I'm trying to get it merged ASAP.

I also plan to write an end-to-end tutorial on training models with AXLearn on GPU, so that any GCP user can take advantage of AXLearn, including those without TPUs.

@samos123 samos123 requested review from kelvin-zou and ruomingp March 3, 2025 21:00
samos123 (Contributor Author) commented Mar 5, 2025

@markblee @kelvin-zou friendly ping. We're starting A4 benchmarking now as well. We are using this PR as the base for that work. So would like to get it merged.

@kelvinzou

@markblee @kelvin-zou friendly ping. We're starting A4 benchmarking now as well. We are using this PR as the base for that work. So would like to get it merged.

Thanks for updating this PR!

Two quick comments:

  1. We don't really use axlearn to submit GCP GPU jobs, so this is more like a PoC.
  2. Given that, can you actually peel off the part that is unique to A3-Ultra? This would be great guidance to us if we ever want to use A3-Ultra. IIUC, compared to A3-Mega, there is the mlx7 replacing the Google NIC and H200 instead of H100, and the main diff should be only this?

Contributor

@markblee markblee left a comment

Thanks! Mostly LGTM. Will let @Ethanlm @kelvinzou leave a review.

Comment on lines 306 to 309
@classmethod
def with_builder(cls, builder: type[BaseReplicatedJob]):
    cls.builder = builder
    return cls
Contributor

Can you follow the pattern here to avoid in-place update:

@classmethod
def with_runner(cls, runner: type[Job]):
    return type(f"{cls.__name__}_{runner.__name__}", (cls,), {"runner": runner})
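
Applied to this PR, the same pattern for with_builder could look roughly like the sketch below (the class bodies are illustrative stand-ins, not the code that was committed):

# Sketch of with_builder following the with_runner pattern above. The class
# bodies are stand-ins; only the with_builder pattern itself is the point.
class BaseReplicatedJob:
    pass

class GPUGKEJob:
    builder: type = None

    @classmethod
    def with_builder(cls, builder: type):
        # Return a new subclass instead of mutating cls in place.
        return type(f"{cls.__name__}_{builder.__name__}", (cls,), {"builder": builder})

job_cls = GPUGKEJob.with_builder(BaseReplicatedJob)
print(job_cls.__name__)  # GPUGKEJob_BaseReplicatedJob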

Contributor Author

done

Contributor Author (samos123, Mar 5, 2025)

Tests are failing after making this change. I will follow up.

AttributeError: type object 'GPUGKEJob' has no attribute 'builder'

Contributor Author

should be fixed now. Tested both gke_runner and job.

@@ -143,6 +147,11 @@ def validate_inner(cls):
        if cls.inner is None:
            raise ValueError(f"A GKERunnerJob should subclass {cls} and define `inner`.")

    @classmethod
    def with_inner(cls, inner: type[GKEJob]):
        cls.inner = inner
Contributor

(Same as above.)

Contributor Author

done

Comment on lines +546 to +549
elif instance_type.startswith("gpu-a3-ultra"):
return GKERunnerJob.with_inner(GKEJob.with_builder(A3UltraReplicatedJob))
elif instance_type.startswith("gpu-a3-high"):
return GKERunnerJob.with_inner(GKEJob.with_builder(A3HighReplicatedJob))
Contributor

Neat!

Contributor Author

This is like praising yourself since you came up with it lol

@markblee markblee requested a review from Ethanlm March 5, 2025 18:29
samos123 (Contributor Author) commented Mar 5, 2025

  1. Given that, can you actually peel off the part that is unique to A3-Ultra? This would be great guidance to us if we ever want to use A3-Ultra. IIUC, compared to A3-Mega, there is the mlx7 replacing the Google NIC and H200 instead of H100, and the main diff should be only this?

Correct about the difference between A3 Mega and A3 Ultra. Note that the total network bandwidth also differs between A3 Mega (1.6 Tbps) and A3 Ultra (3.2 Tbps).

I think that's what I did in this PR, only separating out the things that are unique to A3 Ultra.

Both A3 High and A3 Mega would require their own classes, since they each need sidecars and a specific main container spec. I purposely designed the base class so that A4 would only have to define a new main container (see the sketch below).

I think there may be a chance to reuse the same class for both A3 Ultra and A4, but A4X is likely going to need its own class once again since it uses Arm.
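
A rough sketch of that per-accelerator layout (the method names and container dicts below are illustrative, not code from this PR):

# Illustrative sketch of the per-accelerator class layout described above.
# Method names and the container/sidecar dicts are hypothetical.
class BaseReplicatedJob:
    def _main_container(self) -> dict:
        raise NotImplementedError  # Each accelerator defines its own main container.

    def _sidecar_containers(self) -> list:
        return []  # Most accelerators need no sidecars.

class A3UltraReplicatedJob(BaseReplicatedJob):
    def _main_container(self) -> dict:
        return {"name": "main", "volumeMounts": [{"name": "gib", "mountPath": "/usr/local/gib"}]}

class A3HighReplicatedJob(BaseReplicatedJob):
    def _main_container(self) -> dict:
        return {"name": "main"}

    def _sidecar_containers(self) -> list:
        # A3 High additionally needs a TCPX sidecar sharing a socket volume.
        return [{"name": "tcpx-daemon"}]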

Note we may also build automation to test a multi-node training job on GPU on our side. That would also use this PR.

I will make a follow up commit addressing Mark's comments soon.

@samos123 samos123 requested a review from markblee March 5, 2025 19:31
# NCCL flags needed
env_vars.update(
    {
        # Enable auto PGLE available in jax 0.4.33
Contributor

Does it still apply if we are already on jax 0.4.38?

Contributor Author

removed this comment

k8s_env_vars = [{"name": name, "value": value} for name, value in env_vars.items()]
k8s_env_vars.append(
    {
        "name": "PROCESS_ID",
Contributor

Can you elaborate on how PROCESS_ID is used?
I see this is how it was written in the original code. Just for my education.

Contributor Author

It's an AXLearn-specific thing:

os.environ.get("PROCESS_ID", None),

flags.DEFINE_integer(
    "process_id",
    os.environ.get("PROCESS_ID", None),
    "Rank of the current process. Must be None on tpu, otherwise required.",
)

We set it as an env variable so it works out of the box.
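
For reference, the appended PROCESS_ID entry sources its value from the pod rather than from a literal value; a sketch of what that can look like (the field path below is an assumption based on Kubernetes Indexed Jobs, not copied from this PR):

# Hypothetical continuation of the snippet above: populate PROCESS_ID from the
# Job completion index via the Kubernetes downward API (an assumption, not
# this PR's exact code).
k8s_env_vars.append(
    {
        "name": "PROCESS_ID",
        "valueFrom": {
            "fieldRef": {
                "fieldPath": "metadata.annotations['batch.kubernetes.io/job-completion-index']"
            }
        },
    }
)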

volume_mounts = [
    {"name": "shared-memory", "mountPath": "/dev/shm"},
    {"name": "nvidia-install-dir-host", "mountPath": "/usr/local/nvidia/lib64"},
    {"name": "gib", "mountPath": "/usr/local/gib"},
Contributor

If I understand correctly, the major diff between A3 Ultra and A3 High is the different volume mounts.
Could you please elaborate on what these different mounts mean and why they are needed? Thanks

Contributor Author

In A3 High we have a tcpx volume to share a socket.

In A3 Ultra there is a gib volume that's mounted from the GKE node. The gib volume provides configs needed for CX7 and some binaries:

                "NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE": (
                    "/usr/local/gib/configs/guest_config.txtpb"
                ),
                "NCCL_TUNER_CONFIG_PATH": "/usr/local/gib/configs/tuner_config.txtpb",

samos123 (Contributor Author) commented Mar 7, 2025

@kelvin-zou could you do another review and Approve or Request changes? I responded to your comment but did not make any changes based on it.

Contributor

@kelvin-zou kelvin-zou left a comment

Thanks, the change looks great overall, some small suggestions.

@@ -99,6 +99,9 @@ COPY . .

FROM base AS gpu

# Needed for NVIDIA CX7 based RDMA (not cloud specific).
RUN apt-get update && apt-get install -y ibverbs-utils
Contributor

Nice!

# NCCL flags needed
env_vars.update(
    {
        "JAX_ENABLE_PGLE": "True",
Contributor

Let's turn off PGLE; it has some accuracy issues which we haven't been able to triage.

Contributor Author

Removed, but this will impact performance.

Our Jax GPU team would love to improve that. I do understand that you probably don't have the bandwidth to triage right now. Let me know if you feel there is an easy way for me to reproduce this in AXLearn and I can have someone take that on.

# This is needed for flash attention + auto PGLE to work
"JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY": "True",
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"NVTE_FUSED_ATTN": "1",
Contributor

Not useful; we are not using TE but cuDNN instead, which doesn't require this flag.

Contributor Author

Removed. Great catch! This was coming from Maxtext which I believe was using TE. I was unaware that this only impacted TE.

"NCCL_SOCKET_IFNAME": "=eth0,eth1",
"NCCL_CROSS_NIC": "0",
"NCCL_NET_GDR_LEVEL": "PIX",
"NCCL_P2P_NET_CHUNKSIZE": "65536",
Contributor

Where did all the TCPX flags go?

Also, it might be nice if we can break it into 3 parts:

  1. NCCL flags
  2. TCPX flags
  3. MLX flags

FWIW, this is what we do internally; we also have EFA flags, and we structure them in blocks so we can compose them.

Contributor Author (samos123, Mar 7, 2025)

The TCPX flags are only relevant on A3 High, and they're still in that class. On A3 Mega it's yet another set of NCCL TCPXO flags (note that A3 Mega is not supported in AXLearn today). I still have a separate branch for A3 Mega: main...samos123:axlearn:a3-mega-support

On A3 Ultra and A4 we also use different NCCL flags. It looks like we're recommending different flags as the default for different accelerators. There is less of a difference between A3 Ultra and A4; however, A4X is yet another big change because now we also have to deal with Arm and a lower number of GPUs per node.

On GCP right now it's better to have per-machine-type classes because they are turning out to be so different.

I do like the idea of using recomposable NCCL flags; some of the XLA and NCCL flags could potentially be reusable between GPU accelerators (a rough sketch of what that could look like is below). I suggest we make that change as part of the A4 PR so we have a better view of how different the flags are. There will be an A4 PR very soon :)
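
As an illustration of composable flag blocks (flag values are copied from snippets in this thread; the grouping and the helper are hypothetical, not code from this PR):

# Illustrative sketch of composable NCCL flag blocks. The grouping into blocks
# and the build_env_vars helper are hypothetical, not this PR's code.
BASE_NCCL_FLAGS = {
    "NCCL_SOCKET_IFNAME": "=eth0,eth1",
    "NCCL_CROSS_NIC": "0",
}
MLX_FLAGS = {
    # Assumed A3 Ultra (CX7/RDMA) specific settings.
    "NCCL_NET_GDR_LEVEL": "PIX",
    "NCCL_P2P_NET_CHUNKSIZE": "65536",
    "NCCL_TUNER_CONFIG_PATH": "/usr/local/gib/configs/tuner_config.txtpb",
}

def build_env_vars(*blocks: dict) -> dict:
    """Merge flag blocks so per-accelerator classes can compose them."""
    env: dict = {}
    for block in blocks:
        env.update(block)
    return env

# e.g. an A3 Ultra job could use:
# env_vars.update(build_env_vars(BASE_NCCL_FLAGS, MLX_FLAGS))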

Contributor Author

@kelvin-zou I am fine with making this change in this PR too if you feel strongly about it. I do see the value in it; I just prefer deferring it to the A4 work since that's based off this PR right now.

@samos123 samos123 requested a review from kelvin-zou March 10, 2025 22:55
samos123 (Contributor Author)

We would really like to get this merged since main keeps having significant changes to these files. This is critical for Google to be able to run AXLearn benchmarks. Can you please confirm whether you're fine with the changes overall, so I can rebase and handle the conflicts?

We have an A4/B200 PR that will be submitted soon after this PR.

samos123 (Contributor Author) commented Apr 4, 2025

Closing this in favor of #1097, which adds A3 Ultra, A4, and A3 Mega support.

@samos123 samos123 closed this Apr 4, 2025