Skip to content

Commit

Permalink
High LR config for the 1B
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Nov 5, 2024
1 parent f33abac commit cdf4319
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
37 changes: 37 additions & 0 deletions scripts/augusta/beaker/peteish1-highlr-launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env bash

set -ex

NUM_NODES=$1
shift

gantry run \
--workspace ai2/13B \
--task-name peteish1-highlr \
--description "Peteish1 HighLR" \
--priority high \
--preemptible \
--beaker-image michalg/cuda11.8-ubuntu20.04-arb \
--cluster ai2/augusta-google-1 \
--gpus 8 \
--replicas "${NUM_NODES}" \
--leader-selection \
--host-networking \
--budget ai2/oe-training \
--no-nfs \
--propagate-failure \
--propagate-preemption \
--synchronized-start-timeout 15m \
--no-python \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env-secret WANDB_API_KEY=DIRKG_WANDB_API_KEY \
--env-secret AWS_ACCESS_KEY_ID=DIRKG_AWS_ACCESS_KEY_ID \
--env-secret AWS_SECRET_ACCESS_KEY=DIRKG_AWS_SECRET_ACCESS_KEY \
--shared-memory 10GiB \
--yes \
--timeout=-1 \
--allow-dirty \
--retries 10 \
-- /bin/bash -c "scripts/augusta/beaker/peteish1-highlr.sh \$BEAKER_LEADER_REPLICA_HOSTNAME \$BEAKER_REPLICA_RANK"
87 changes: 87 additions & 0 deletions scripts/augusta/beaker/peteish1-highlr.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env bash

set -exuo pipefail
IFS=$'\n\t'

BEAKER_LEADER_REPLICA_HOSTNAME=$1
shift

BEAKER_REPLICA_RANK=$1
shift

# augusta specific environment
export LD_LIBRARY_PATH="/var/lib/tcpxo/lib64:${LD_LIBRARY_PATH}"
export NCCL_CROSS_NIC=0
export NCCL_ALGO=Ring,Tree
export NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=4
export NCCL_P2P_NET_CHUNKSIZE=524288
export NCCL_P2P_PCI_CHUNKSIZE=524288
export NCCL_P2P_NVL_CHUNKSIZE=1048576
export NCCL_FASTRAK_NUM_FLOWS=2
export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
export NCCL_BUFFSIZE=8388608
export NCCL_FASTRAK_USE_SNAP=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NCCL_NET_GDR_LEVEL=PIX
export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
export NCCL_TUNER_PLUGIN=libnccl-tuner.so
export NCCL_TUNER_CONFIG_PATH=/var/lib/tcpxo/lib64/a3plus_tuner_config.textproto
export NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/var/lib/tcpxo/lib64/a3plus_guest_config.textproto
export NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000
export NCCL_NVLS_ENABLE=0
export NCCL_DEBUG=WARN
export NCCL_FASTRAK_CTRL_DEV=enp0s12
export NCCL_FASTRAK_IFNAME=enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0
export NCCL_SOCKET_IFNAME=enp0s12
export NCCL_USE_SNAP=1
export NCCL_FASTRAK_USE_LLCM=1
export NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY=/dev/aperture_devices

# Install flash-attn
#conda install -y pytorch-cuda==12.4 packaging ninja cccl cuda-nvcc libcusolver-dev cuda-profiler-api libcusparse-dev libcublas-dev -c pytorch -c nvidia
#pip install flash-attn==2.5.9.post1 --no-build-isolation
pip install '.[train]'
pip freeze

# Force processes to synchronize at init_process_group
export TORCH_DIST_INIT_BARRIER=1
# Better error handling from Python
export PYTHONFAULTHANDLER=1

NAME=${GANTRY_TASK_NAME// /_}
RUN_NAME=$NAME-$(date -u +"%Y%m%d_%H%M%S")
SAVE_FOLDER=/data/$RUN_NAME
mkdir -p $SAVE_FOLDER

torchrun \
--nnodes "${BEAKER_REPLICA_COUNT}:${BEAKER_REPLICA_COUNT}" \
--nproc-per-node 8 \
--rdzv_id 12348 \
--rdzv_backend static \
--rdzv_endpoint "${BEAKER_LEADER_REPLICA_HOSTNAME}:29400" \
--node_rank "${BEAKER_REPLICA_RANK}" \
--rdzv_conf 'read_timeout=420' \
scripts/train.py \
configs/peteish1-google.yaml \
--run_name=$RUN_NAME \
--wandb.group=$NAME \
--optimizer.learning_rate=12.0e-4 \
--save_interval_ephemeral=10000 \
--eval_interval=10000 \
--fsdp.sharding_strategy=HYBRID_SHARD \
--fsdp.hybrid_sharding_num_model_replicas="${BEAKER_REPLICA_COUNT}" \
--fsdp.wrapping_strategy=by_block_and_size \
--save_folder=$SAVE_FOLDER \
--remote_save_folder="gs://ai2-llm/checkpoints/OLMo-medium/$NAME/" \
--try_load_latest_save \
--save_overwrite \
--sharded_checkpointer=olmo_core \
--device_train_microbatch_size=4 \
--device_eval_batch_size=8 \
--compile.fullgraph=false \
--fused_loss=false \
--model.flash_attention=false \
--data.num_workers=32 \
--optimizer.metrics_log_interval=10 \
--data.prefetch_factor=8

0 comments on commit cdf4319

Please sign in to comment.