diff --git a/scripts/augusta/beaker/peteish1-highlr-launch.sh b/scripts/augusta/beaker/peteish1-highlr-launch.sh
new file mode 100755
index 000000000..3e7375709
--- /dev/null
+++ b/scripts/augusta/beaker/peteish1-highlr-launch.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+set -ex
+
+NUM_NODES=$1
+shift
+
+gantry run \
+  --workspace ai2/13B \
+  --task-name peteish1-highlr \
+  --description "Peteish1 HighLR" \
+  --priority high \
+  --preemptible \
+  --beaker-image michalg/cuda11.8-ubuntu20.04-arb \
+  --cluster ai2/augusta-google-1 \
+  --gpus 8 \
+  --replicas "${NUM_NODES}" \
+  --leader-selection \
+  --host-networking \
+  --budget ai2/oe-training \
+  --no-nfs \
+  --propagate-failure \
+  --propagate-preemption \
+  --synchronized-start-timeout 15m \
+  --no-python \
+  --env LOG_FILTER_TYPE=local_rank0_only \
+  --env OMP_NUM_THREADS=8 \
+  --env OLMO_TASK=model \
+  --env-secret WANDB_API_KEY=DIRKG_WANDB_API_KEY \
+  --env-secret AWS_ACCESS_KEY_ID=DIRKG_AWS_ACCESS_KEY_ID \
+  --env-secret AWS_SECRET_ACCESS_KEY=DIRKG_AWS_SECRET_ACCESS_KEY \
+  --shared-memory 10GiB \
+  --yes \
+  --timeout=-1 \
+  --allow-dirty \
+  --retries 10 \
+  -- /bin/bash -c "scripts/augusta/beaker/peteish1-highlr.sh \$BEAKER_LEADER_REPLICA_HOSTNAME \$BEAKER_REPLICA_RANK"
diff --git a/scripts/augusta/beaker/peteish1-highlr.sh b/scripts/augusta/beaker/peteish1-highlr.sh
new file mode 100755
index 000000000..e96d02f16
--- /dev/null
+++ b/scripts/augusta/beaker/peteish1-highlr.sh
@@ -0,0 +1,87 @@
+#!/usr/bin/env bash
+
+set -exuo pipefail
+IFS=$'\n\t'
+
+BEAKER_LEADER_REPLICA_HOSTNAME=$1
+shift
+
+BEAKER_REPLICA_RANK=$1
+shift
+
+# Augusta-specific environment
+export LD_LIBRARY_PATH="/var/lib/tcpxo/lib64:${LD_LIBRARY_PATH}"
+export NCCL_CROSS_NIC=0
+export NCCL_ALGO=Ring,Tree
+export NCCL_PROTO=Simple
+export NCCL_MIN_NCHANNELS=4
+export NCCL_P2P_NET_CHUNKSIZE=524288
+export NCCL_P2P_PCI_CHUNKSIZE=524288
+export NCCL_P2P_NVL_CHUNKSIZE=1048576
+export NCCL_FASTRAK_NUM_FLOWS=2
+export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
+export NCCL_BUFFSIZE=8388608
+export NCCL_FASTRAK_USE_SNAP=1
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export NCCL_NET_GDR_LEVEL=PIX
+export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
+export NCCL_TUNER_PLUGIN=libnccl-tuner.so
+export NCCL_TUNER_CONFIG_PATH=/var/lib/tcpxo/lib64/a3plus_tuner_config.textproto
+export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/var/lib/tcpxo/lib64/a3plus_guest_config.textproto
+export NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000
+export NCCL_NVLS_ENABLE=0
+export NCCL_DEBUG=WARN
+export NCCL_FASTRAK_CTRL_DEV=enp0s12
+export NCCL_FASTRAK_IFNAME=enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0
+export NCCL_SOCKET_IFNAME=enp0s12
+export NCCL_USE_SNAP=1
+export NCCL_FASTRAK_USE_LLCM=1
+export NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY=/dev/aperture_devices
+
+# Install flash-attn
+#conda install -y pytorch-cuda==12.4 packaging ninja cccl cuda-nvcc libcusolver-dev cuda-profiler-api libcusparse-dev libcublas-dev -c pytorch -c nvidia
+#pip install flash-attn==2.5.9.post1 --no-build-isolation
+pip install '.[train]'
+pip freeze
+
+# Force processes to synchronize at init_process_group
+export TORCH_DIST_INIT_BARRIER=1
+# Better error handling from Python
+export PYTHONFAULTHANDLER=1
+
+NAME=${GANTRY_TASK_NAME// /_}
+RUN_NAME=$NAME-$(date -u +"%Y%m%d_%H%M%S")
+SAVE_FOLDER=/data/$RUN_NAME
+mkdir -p $SAVE_FOLDER
+
+torchrun \
+  --nnodes "${BEAKER_REPLICA_COUNT}:${BEAKER_REPLICA_COUNT}" \
+  --nproc-per-node 8 \
+  --rdzv_id 12348 \
+  --rdzv_backend static \
+  --rdzv_endpoint "${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" \
+  --node_rank "${BEAKER_REPLICA_RANK}" \
+  --rdzv_conf 'read_timeout=420' \
+  scripts/train.py \
+    configs/peteish1-google.yaml \
+    --run_name=$RUN_NAME \
+    --wandb.group=$NAME \
+    --optimizer.learning_rate=12.0e-4 \
+    --save_interval_ephemeral=10000 \
+    --eval_interval=10000 \
+    --fsdp.sharding_strategy=HYBRID_SHARD \
+    --fsdp.hybrid_sharding_num_model_replicas="${BEAKER_REPLICA_COUNT}" \
+    --fsdp.wrapping_strategy=by_block_and_size \
+    --save_folder=$SAVE_FOLDER \
+    --remote_save_folder="gs://ai2-llm/checkpoints/OLMo-medium/$NAME/" \
+    --try_load_latest_save \
+    --save_overwrite \
+    --sharded_checkpointer=olmo_core \
+    --device_train_microbatch_size=4 \
+    --device_eval_batch_size=8 \
+    --compile.fullgraph=false \
+    --fused_loss=false \
+    --model.flash_attention=false \
+    --data.num_workers=32 \
+    --optimizer.metrics_log_interval=10 \
+    --data.prefetch_factor=8