
imported os and added ckpt scripts #41

Open
wants to merge 4 commits into base: accuracy_workstream_trn
Conversation

dgourab-aws (Collaborator)

No description provided.

@dgourab-aws (Collaborator Author)

Tested it locally using: pytest seed_test.py

@HahTK (Collaborator) left a comment

Most of the stuff we need is missing.

  1. Most features are missing (setting the seed, parameterizing fuji.py, all features from GPU).
  2. Why do we only have GPU training scripts in the TRN repo? Where is the TRN script?
  3. This is likely completely untested. Did we run a TRN job?


# export JAX_PLATFORMS=cpu

#Perf Tuning Guideline here : https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/PGLE.md
Collaborator

Why are GPU flags in TRN runs?

Collaborator Author

This is just the checkpointing script; I did not do any cleanup here.

###export NCCL_DEBUG_SUBSYS=COLL

#HAH quick fix
export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DUMP_PATH} --xla_dump_hlo_pass_re='.*' --xla_dump_hlo_as_proto --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_enable_custom_fusions=true" # --xla_gpu_enable_address_computation_fusion=true"
Collaborator

Why are GPU flags in TRN runs?

Collaborator Author

This is just the checkpointing script.
@HahTK I need a walkthrough of this script to actually clean it up. I haven't used it to launch TRN jobs; I have been using Apoorv's launch script.

echo "ERROR : ${TEST_SETUP} for ${N_EXPECTED_NODES} was launched with ${num_nodes}"
exit 1
fi
MESH_SELECTOR="gpu-${num_nodes}node-baseline"
Collaborator

Again, everything is GPU here.

@@ -0,0 +1,150 @@
#!/usr/bin/env bash
Collaborator

This needs to be fixed. This is a GPU script being used to run TRN. We need to use the TRN script and just add checkpoint resume to it.

import importlib


class SeedTest(test_utils.TestCase):
Collaborator

Where do we actually set the seed?

Collaborator Author

The seed has to be set as an environment variable from any launch script like this:
export DATA_SEED=42
The launch script has not been added to the PR.
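
As a reference, here is a minimal sketch of how a launch or test script could consume that variable. Only the DATA_SEED name comes from this thread; the get_data_seed helper and the use of jax.random.PRNGKey are illustrative assumptions, not code from this PR.

# Hypothetical sketch (not part of this PR): read DATA_SEED from the environment
# and use it to seed the JAX PRNG. Helper name and defaults are illustrative.
import os

import jax


def get_data_seed(default: int = 0) -> int:
    """Return the data seed from the environment, falling back to a default."""
    return int(os.environ.get("DATA_SEED", default))


if __name__ == "__main__":
    seed = get_data_seed()
    key = jax.random.PRNGKey(seed)  # deterministic key for data shuffling, etc.
    print(f"DATA_SEED={seed}")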

@HahTK (Collaborator) commented Dec 27, 2024

Also, the branch was created from the wrong commit ID. It should have been
33ec152

but it seems to have been branched from this instead:
c20387c

@dgourab-aws (Collaborator Author)

> Also the branch was created from the wrong commit id. It should have been 33ec152
>
> but it seems to be branched from this instead c20387c

The GPU branch was created from 33ec152; the TRN branch was to be created from the AXLearn upstream branch.
