imported os and added ckpt scripts #41
base: accuracy_workstream_trn
Conversation
Tested it locally using: pytest seed_test.py
Most of the stuff we need is missing.
- Most features are missing (setting the seed, parameterizing fuji.py, all features from GPU)
- We only have GPU training scripts in the TRN repo? Where is the TRN script?
- Likely completely untested. Did we run a TRN job?
# export JAX_PLATFORMS=cpu
#Perf Tuning Guideline here: https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/PGLE.md
why are GPU flags in TRN runs?
This is just the checkpointing script, I did not do any cleanup here.
###export NCCL_DEBUG_SUBSYS=COLL
#HAH quick fix
export XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=${HLO_DUMP_PATH} --xla_dump_hlo_pass_re='.*' --xla_dump_hlo_as_proto --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_enable_custom_fusions=true" # --xla_gpu_enable_address_computation_fusion=true"
why are GPU flags in TRN runs?
This is just the checkpointing script.
@HahTK I need a walkthrough of this script to actually clean it up. I haven't used this to launch Trn jobs; I was using Apoorv's launch script.
echo "ERROR : ${TEST_SETUP} for ${N_EXPECTED_NODES} was launched with ${num_nodes}" | ||
exit 1 | ||
fi | ||
MESH_SELECTOR="gpu-${num_nodes}node-baseline" |
again everything is GPU here
@@ -0,0 +1,150 @@
#!/usr/bin/env bash
this needs to be fixed. This is a GPU script used to run TRN. We need to use the TRN script and just add ckpt resume to it.
import importlib


class SeedTest(test_utils.TestCase):
where do we actually set the seed?
The seed has to be set as an environment variable from any launch script like this:
export DATA_SEED=42
The launch script has not been added to the PR.
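For context, a minimal sketch of how the consuming code might read this variable, assuming it is picked up via `os.environ` (as the PR title's "imported os" suggests); the `get_data_seed` helper and the PRNG usage below are hypothetical, since the launch script and the consuming code are not part of this PR:

```python
import os

import jax


def get_data_seed(default: int = 0) -> int:
    # Hypothetical helper: read the seed exported by the launch script (export DATA_SEED=42).
    return int(os.environ.get("DATA_SEED", default))


# Example: derive the input-pipeline PRNG key from the launched seed.
input_key = jax.random.PRNGKey(get_data_seed())
```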