Add GKE A3 Ultra support #940
Conversation
samos123 commented on Jan 22, 2025
- Adds support for easily adding more GKE GPU accelerators by using a base class
- Adds Fuji v2 70B benchmark results
Dockerfile
Outdated
@@ -101,14 +93,63 @@ COPY . .
# GPU container spec. #
################################################################################

FROM base AS gpu
# This causes INTERNAL: No valid engine configs for Matmul error
@markblee @kelvin-zou are you fine with moving to nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as the base image? I couldn't get A3 Ultra working on the original python base image.
Dockerfile
Outdated
# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install .[core,gpu]
COPY . .
RUN pip install -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38" \
TODO: this should be removed. Maybe we should wait until axlearn main upgrades to 0.4.38.
Thanks!
Dockerfile
Outdated
# So we're using the NVIDIA provided cuda image instead which works.
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as gpu

# Copy from original base
Does this mean we will need to keep the following commands consistent with those in BASE? Does Dockerfile support functions?
I did some research and sadly the answer is no. We could try to create a bash script that gets used in both instead of copy pasting code. I do agree that we should figure out a way to reuse the same setup steps.
Dockerfile
Outdated
# Using `FROM base as GPU` causes INTERNAL: No valid engine configs for Matmul error.
# So we're using the NVIDIA provided cuda image instead which works.
FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 as gpu
Let's make sure this change doesn't break our GPU training on AWS.
I think it's fine but defer to @kelvin-zou
Seconding what Ruoming recommended: maybe split this into Dockerfile.gcp if there are GCP-specific things, and import this Dockerfile as the base. The base image import is fine though.
There isn't anything GCP-specific except the installation of gcloud, which also happens in the axlearn base image. Happy to move to Dockerfile.gcp though, since it uses a different base image. I think that makes sense.
I actually think it's better to keep this as-is; we don't want two separate GPU images, one broken and one working. At some point it may be helpful to switch to the Jax Stable Stack image so we can re-use the same validated base image on both GPU and TPU.
Are you fine with keeping it as-is, since there is nothing GCP-specific happening in the gpu image?
I still have access to internal capacity, so I wanted to get some reviews and make changes before I lose the ability to verify any changes.
Dockerfile
Outdated
# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install .[core,gpu]
COPY . .

# TODO(samos123): remove this once axlearn upgrades to Jax 0.4.38.
RUN pip install -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38"
Can we rely on axlearn's jax version instead?
Yes, will remove this before merge since I think 0.4.37 works as well.
Also refactored a lot of code; possibly broken right now.
I was able to get it working with the original Docker image for GPU. I think the newer Jax and NVIDIA drivers may have removed the need to use the CUDA image from NVIDIA. @markblee @kelvin-zou could you please review again and get this merged? We have people ready to start adding A4 support and benchmarking it, but that work builds on top of this PR, so I'm trying to get this merged ASAP. I also plan to write an end-to-end tutorial for training models with AXLearn on GPU, so any GCP user can take advantage of AXLearn, including those without TPUs.
@markblee @kelvin-zou friendly ping. We're starting A4 benchmarking now as well, using this PR as the base for that work, so I would like to get it merged.
Thanks for updating this PR! Two quick comments.
Thanks! Mostly LGTM. Will let @Ethanlm @kelvinzou leave a review.
axlearn/cloud/gcp/job.py
Outdated
@classmethod
def with_builder(cls, builder: type[BaseReplicatedJob]):
    cls.builder = builder
    return cls
Can you follow the pattern here to avoid the in-place update:

axlearn/axlearn/cloud/gcp/jobs/launch.py, lines 214 to 216 in dcfdbc1:

@classmethod
def with_runner(cls, runner: type[Job]):
    return type(f"{cls.__name__}_{runner.__name__}", (cls,), {"runner": runner})
done
Tests are failing after making this change; I will follow up.
AttributeError: type object 'GPUGKEJob' has no attribute 'builder'
Should be fixed now. Tested both gke_runner and job.
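The pattern being requested can be sketched as a small standalone example (class names here are simplified stand-ins for the axlearn classes). Instead of mutating the class attribute in place, the three-argument form of `type()` creates a fresh subclass per builder, so two call sites with different builders cannot clobber each other:

```python
class BaseReplicatedJob:
    """Stand-in for axlearn's builder base class (sketch)."""


class A3UltraReplicatedJob(BaseReplicatedJob):
    pass


class GPUGKEJob:
    builder: type = None  # Populated per-variant by with_builder().

    @classmethod
    def with_builder(cls, builder: type):
        # Return a fresh subclass instead of mutating `cls` in place, so the
        # base class stays pristine for other call sites.
        return type(f"{cls.__name__}_{builder.__name__}", (cls,), {"builder": builder})


job_cls = GPUGKEJob.with_builder(A3UltraReplicatedJob)
```

After this call, `job_cls.builder` is `A3UltraReplicatedJob` while `GPUGKEJob.builder` remains `None`, which is exactly why tests that instantiate the bare `GPUGKEJob` must go through `with_builder` first.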
axlearn/cloud/gcp/jobs/gke_runner.py
Outdated
@@ -143,6 +147,11 @@ def validate_inner(cls):
        if cls.inner is None:
            raise ValueError(f"A GKERunnerJob should subclass {cls} and define `inner`.")

    @classmethod
    def with_inner(cls, inner: type[GKEJob]):
        cls.inner = inner
(Same as above.)
done
elif instance_type.startswith("gpu-a3-ultra"):
    return GKERunnerJob.with_inner(GKEJob.with_builder(A3UltraReplicatedJob))
elif instance_type.startswith("gpu-a3-high"):
    return GKERunnerJob.with_inner(GKEJob.with_builder(A3HighReplicatedJob))
Neat!
This is like praising yourself since you came up with it lol
Correct about the difference between A3 Mega and A3 Ultra. Note that total bandwidth differs between A3 Mega (1.6 Tbps) and A3 Ultra (3.2 Tbps). I think that's what I did in this PR: only separating out the things that are unique to A3 Ultra. Both A3 High and A3 Mega would require their own classes, since they have different sidecars and main container specs.

I purposely created the base class in a way that A4 would only have to define a new main container; A3 High and A3 Mega both need sidecars and a specific main container. There may be a chance to re-use the same class for both A3 Ultra and A4, but A4X is likely going to need its own class once again since it's using Arm.

Note we may also build automation on our side to test a multi-node GPU training job; that would also use this PR. I will make a follow-up commit addressing Mark's comments soon.
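The base-class design described above can be sketched roughly like this (hypothetical, heavily simplified names; the real classes live in jobset_utils.py and carry far more configuration). A new accelerator such as A4 would only override the main container, while A3 High/Mega would also override the sidecars:

```python
class BaseGPUReplicatedJob:
    """Sketch of a per-accelerator base class (illustrative, not axlearn's code)."""

    def _build_main_container(self) -> dict:
        raise NotImplementedError  # Every accelerator defines its own.

    def _build_sidecars(self) -> list:
        return []  # Most accelerators (e.g. A3 Ultra, A4) need no sidecar.

    def build_pod_spec(self) -> dict:
        return {"containers": [self._build_main_container(), *self._build_sidecars()]}


class A3UltraJob(BaseGPUReplicatedJob):
    def _build_main_container(self) -> dict:
        return {
            "name": "main",
            "volumeMounts": [{"name": "gib", "mountPath": "/usr/local/gib"}],
        }


class A3HighJob(BaseGPUReplicatedJob):
    def _build_main_container(self) -> dict:
        return {"name": "main"}

    def _build_sidecars(self) -> list:
        # A3 High additionally needs a TCPX sidecar sharing a socket volume.
        return [{"name": "tcpx-daemon"}]
```

Under this sketch an A3 Ultra pod has a single container, while A3 High gets the extra sidecar, matching the split described in the comment above.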
axlearn/cloud/gcp/jobset_utils.py
Outdated
# NCCL flags needed
env_vars.update(
    {
        # Enable auto PGLE available in jax 0.4.33
Does it still apply if we are already on jax 0.4.38?
Removed this comment.
k8s_env_vars = [{"name": name, "value": value} for name, value in env_vars.items()]
k8s_env_vars.append(
    {
        "name": "PROCESS_ID",
Can you elaborate on how PROCESS_ID is used? I see this is how it was written in the original code; just for my education.
It's an AXLearn-specific thing:

axlearn/axlearn/common/launch.py, line 78 in b1e7b37:

flags.DEFINE_integer(
    "process_id",
    os.environ.get("PROCESS_ID", None),
    "Rank of the current process. Must be None on tpu, otherwise required.",
)
We set it as an env variable so it works out of the box.
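Putting the two pieces together, a minimal sketch of the plumbing (`build_k8s_env` is a hypothetical helper name): the launcher converts a flat dict into the Kubernetes container `env` list shape, and the trainer reads the rank back out of the environment the way launch.py does:

```python
import os

def build_k8s_env(env_vars: dict) -> list:
    """Convert a flat dict into the Kubernetes container `env` list shape."""
    return [{"name": k, "value": v} for k, v in env_vars.items()]

# Launcher side: each replica gets its rank injected as PROCESS_ID.
env = build_k8s_env({"PROCESS_ID": "0"})

# Trainer side: the rank is read back from the environment at startup.
os.environ.setdefault("PROCESS_ID", "0")  # Simulated here; set by k8s in practice.
process_id = os.environ.get("PROCESS_ID", None)
```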
volume_mounts = [
    {"name": "shared-memory", "mountPath": "/dev/shm"},
    {"name": "nvidia-install-dir-host", "mountPath": "/usr/local/nvidia/lib64"},
    {"name": "gib", "mountPath": "/usr/local/gib"},
If I understand correctly, the major diff between A3 Ultra and A3 High is different volume mounts.
Could you please elaborate on what these different mounts mean and why? Thanks.
In A3 High we have a tcpx volume to share a socket.
In A3 Ultra there is a gib volume that's mounted from the GKE node. The gib volume provides configs needed for CX7 and some binaries:

"NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE": (
    "/usr/local/gib/configs/guest_config.txtpb"
),
"NCCL_TUNER_CONFIG_PATH": "/usr/local/gib/configs/tuner_config.txtpb",
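For concreteness, a small sketch of how the gib mount pairs up with a node-provided volume in the pod spec (the hostPath values are assumptions for illustration, not taken from this PR):

```python
# Each volumeMount references a declared volume by name; `gib` is backed by a
# hostPath that the GKE node image provides (paths below are illustrative).
volumes = [
    {"name": "shared-memory", "emptyDir": {"medium": "Memory"}},
    {"name": "nvidia-install-dir-host", "hostPath": {"path": "/home/kubernetes/bin/nvidia"}},
    {"name": "gib", "hostPath": {"path": "/home/kubernetes/bin/gib"}},
]
volume_mounts = [
    {"name": "shared-memory", "mountPath": "/dev/shm"},
    {"name": "nvidia-install-dir-host", "mountPath": "/usr/local/nvidia/lib64"},
    {"name": "gib", "mountPath": "/usr/local/gib"},
]
# Sanity check: every mount resolves to a declared volume.
assert {m["name"] for m in volume_mounts} <= {v["name"] for v in volumes}
```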
@kelvin-zou could you do another review and Approve or Request changes? I responded to your comment but did not make any changes based on it.
Thanks, the change looks great overall, some small suggestions.
@@ -99,6 +99,9 @@ COPY . .

FROM base AS gpu

# Needed for NVIDIA CX7 based RDMA (not cloud specific).
RUN apt-get update && apt-get install -y ibverbs-utils
Nice!
axlearn/cloud/gcp/jobset_utils.py
Outdated
# NCCL flags needed
env_vars.update(
    {
        "JAX_ENABLE_PGLE": "True",
Let's turn off PGLE; it has some accuracy issues which we haven't been able to triage.
Removed, but it will impact performance.
Our Jax GPU team would love to improve that. I understand you probably don't have bandwidth to triage right now; let me know if there is an easy way for me to reproduce it in AXLearn and I can have someone take that on.
axlearn/cloud/gcp/jobset_utils.py
Outdated
# This is needed for flash attention + auto PGLE to work
"JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY": "True",
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
"NVTE_FUSED_ATTN": "1",
Not useful; we are not using TE but cuDNN instead, which doesn't require this flag.
Removed. Great catch! This was coming from MaxText, which I believe was using TE; I was unaware this flag only affected TE.
"NCCL_SOCKET_IFNAME": "=eth0,eth1",
"NCCL_CROSS_NIC": "0",
"NCCL_NET_GDR_LEVEL": "PIX",
"NCCL_P2P_NET_CHUNKSIZE": "65536",
Where did all the TCPx flags go?
Also, it may be nice if we can break this into 3 parts:
- nccl flags
- tcpx flags
- mlx flags
Fwiw, this is what we do internally; we also have EFA flags, and we structure them in blocks so we can compose them.
The tcpx flags are only relevant on A3 High, and they're still in that class. On A3 Mega it's yet another set of NCCL TCPXO flags (note that A3 Mega is not supported in AXLearn today). I still have a separate branch for A3 Mega: main...samos123:axlearn:a3-mega-support

On A3 Ultra and A4 we also use different NCCL flags. It looks like we're recommending different default flags for each accelerator. There is less of a difference between A3 Ultra and A4; however, A4X is yet another big change, because we also have to deal with Arm and a lower number of GPUs per node.

On GCP right now it's better to have per-machine-type classes because they are turning out to be so different.

I do like the idea of composable NCCL flags. Some of the XLA and NCCL flags could potentially be re-used between GPU accelerators. I suggest we make that change as part of the A4 PR, so we have a better view of how different the flags are. There will be an A4 PR very soon :)
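The composable-blocks idea suggested above could be sketched as plain dicts merged per machine type. The A3 Ultra flag values below come from this PR; the grouping, the function name, and the TCPX flag shown are illustrative assumptions:

```python
# Common NCCL flags shared across GPU machine types (sketch).
NCCL_COMMON = {
    "NCCL_CROSS_NIC": "0",
    "NCCL_P2P_NET_CHUNKSIZE": "65536",
}

# Fabric-specific blocks, composed on top of the common ones.
NCCL_A3_ULTRA = {
    "NCCL_SOCKET_IFNAME": "=eth0,eth1",
    "NCCL_NET_GDR_LEVEL": "PIX",
}
NCCL_A3_HIGH_TCPX = {
    # Hypothetical example of a TCPX-only flag living in its own block.
    "NCCL_GPUDIRECTTCPX_CTRL_DEV": "eth0",
}

def nccl_flags(machine_type: str) -> dict:
    """Compose the env-var block for a machine type (illustrative sketch)."""
    block = dict(NCCL_COMMON)
    if machine_type.startswith("gpu-a3-ultra"):
        block.update(NCCL_A3_ULTRA)
    elif machine_type.startswith("gpu-a3-high"):
        block.update(NCCL_A3_HIGH_TCPX)
    return block
```

This keeps the per-fabric flags isolated while sharing the common block, which is the composition the reviewer describes doing internally with EFA flags.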
@kelvin-zou I am fine making this change in this PR too if you feel strongly about it. I do see the value in it; I just prefer deferring it to the A4 work, since that's based off this PR right now.
We would really like to get this merged, since main keeps having significant changes to these files. This is critical for Google to be able to run AXLearn benchmarks. Can you please confirm whether you're fine with the changes overall, so I can rebase and handle the conflicts? We have an A4/B200 PR that will be submitted soon after this one.
Closing this in favor of #1097, which adds A3 Ultra, A4, and A3 Mega support.