lighteval support after checkpoint, UX refactor #222

Open · wants to merge 45 commits into base: main

Commits (45)
5eb2549
adding slurm config as an argument to better generate slurm for launc…
eliebak Aug 23, 2024
3875c60
working version of lighteval after s3 on 1 node
eliebak Aug 24, 2024
e609d1c
add first version of launcher (still ugly)
eliebak Aug 24, 2024
a5c0cc2
not yet functional, lighteval stuff to figure out
eliebak Aug 25, 2024
6b58c25
remove torch.compile() bc it's not working (might be a me pb)
eliebak Aug 25, 2024
34b50a6
update launcher.py
eliebak Aug 25, 2024
7a105be
fancy launcher
eliebak Aug 26, 2024
28770a5
fancy++ launcher
eliebak Aug 26, 2024
7652089
add the possibility to override config yeahhh
eliebak Aug 26, 2024
4e2d7d9
don't run lighteval runner if no s3 uploader AND slurm (might want ch…
eliebak Aug 26, 2024
e7f0437
add CUDA__DEVICE_MAX_CONNECTIONS=1 in interactive mode
eliebak Aug 26, 2024
bb45352
add create_config, moove log_path to general
eliebak Aug 26, 2024
e9d4a2e
fix launcher and create_config file, still need some improvement for …
eliebak Aug 27, 2024
79ae2cb
lot of changes, working on 1 node with s3, will test the rest soon it…
eliebak Aug 30, 2024
cfcbd70
delete the SlurmArgs and add config to be more cluster agnostic + oth…
eliebak Aug 31, 2024
0d43a95
update wandb restart logic + logging the id and project to pass it to…
eliebak Sep 1, 2024
207797e
better wandb loggin, s3upload only for dl ckpt, correct Path and xPat…
eliebak Sep 1, 2024
8ce8b18
fix some bug with the slurm related stuff
eliebak Sep 2, 2024
6dd81b2
add back slurm
eliebak Sep 2, 2024
4750736
fix some stuff + introduce --base-config
eliebak Sep 2, 2024
90860f5
fix the computation calculation by adding GQA and layer norm at diff…
eliebak Sep 3, 2024
28b3847
change the localisation of get_llama_param_count()
eliebak Sep 3, 2024
157c2ae
change G to B i think it's better
eliebak Sep 3, 2024
6daa717
last fix
eliebak Sep 3, 2024
17bfd5f
create_config is smollm-135M toy example
eliebak Sep 3, 2024
43728d5
last fix
eliebak Sep 3, 2024
b646980
Merge branch 'main' into add-lighteval-after-ckpt
eliebak Sep 3, 2024
714644d
update test and flavours
eliebak Sep 4, 2024
930add6
forgot datasets
eliebak Sep 4, 2024
fd21322
fix wandb import
eliebak Sep 4, 2024
03e0e82
no need to modify this
eliebak Sep 4, 2024
4acf9bc
Merge branch 'main' into add-lighteval-after-ckpt
eliebak Sep 10, 2024
ab1e3c9
remove debugging print
eliebak Sep 10, 2024
7649815
change the lighteval path to the main repo
eliebak Sep 10, 2024
065d9b1
fix the interactive cases if we request less gpus than available
eliebak Sep 10, 2024
a7804f5
remove the base-configs args
eliebak Sep 10, 2024
efce15b
fix bs and gbs
eliebak Sep 10, 2024
73da086
fix the logs structure
eliebak Sep 10, 2024
67115a5
remove layer norm flops
eliebak Sep 13, 2024
6249264
forget comma
eliebak Sep 13, 2024
11d60c8
put the comma in the right place
eliebak Sep 13, 2024
5e8361c
adapt it to the current lighteval main
eliebak Sep 20, 2024
43c833f
remove print
eliebak Sep 20, 2024
3d7c98f
change after review
eliebak Sep 24, 2024
e74ffd1
uncomment logging item
eliebak Sep 26, 2024
237 changes: 237 additions & 0 deletions create_config.py
@@ -0,0 +1,237 @@
import argparse
import math
from datetime import datetime
from pathlib import Path

import torch
from nanotron.config import (
    AdamWOptimizerArgs,
    CheckpointsArgs,
    Config,
    DataArgs,
    DatasetStageArgs,
    GeneralArgs,
    LoggingArgs,
    LRSchedulerArgs,
    ModelArgs,
    OptimizerArgs,
    ParallelismArgs,
    PretrainDatasetsArgs,
    RandomInit,
    TokenizerArgs,
    TokensArgs,
)
from nanotron.models.llama import LlamaConfig

if __name__ == "__main__":
    ###########################################
    ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU)

    HF_USER_OR_ORG = None
    TRAIN_STEPS = 100
    CHECKPOINT_INTERVAL = 200
    SAVE_NAME = "smollm-135M-1gpu-toy"

    ###########################################

    parser = argparse.ArgumentParser()
    parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml")
    parser.add_argument("--seed", help="seed", type=int, default=8)
    args = parser.parse_args()

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    general = GeneralArgs(
        project="smollm",
        run="toy-smollm",
        seed=args.seed,
        temp_dir="temp",
    )

    model_config = LlamaConfig(
        bos_token_id=0,
        eos_token_id=0,
        hidden_act="silu",
        hidden_size=576,
        initializer_range=0.02,
        intermediate_size=1536,
        max_position_embeddings=2048,
        num_attention_heads=9,
        num_hidden_layers=30,
        num_key_value_heads=3,
        pretraining_tp=1,
        rms_norm_eps=1e-05,
        rope_scaling=None,
        tie_word_embeddings=True,
        use_cache=True,
        vocab_size=49152,
    )

    # Uncomment to evaluate the model on a set of tasks with lighteval during training.
    # lighteval = LightEvalConfig(
    #     tasks=LightEvalTasksArgs(
    #         tasks="early-signal",  # "generatives", "all"
    #         custom_tasks="nanotron.lighteval.evaluation_tasks",
    #         max_samples=1000,
    #         dataset_loading_processes=8,
    #     ),
    #     parallelism=ParallelismArgs(
    #         dp=8,
    #         pp=1,
    #         tp=1,
    #         pp_engine="1f1b",
    #         tp_mode="ALL_REDUCE",
    #         # recompute_granularity="selective",
    #         tp_linear_async_communication=False,
    #     ),
    #     batch_size=16,
    #     logging=LightEvalLoggingArgs(
    #         output_dir=None,
    #         push_to_hub=True,
    #         push_to_tensorboard=True,
    #         public_run=False,
    #         results_org=HF_USER_OR_ORG,
    #         tensorboard_metric_prefix="eval",
    #     ),
    # )

    lighteval = None

    checkpoints = CheckpointsArgs(
        # checkpoints_path="checkpoints",
        checkpoints_path_is_shared_file_system=False,
        # resume_checkpoint_path="local_path/to/checkpoint" or s3_path,
        checkpoint_interval=CHECKPOINT_INTERVAL,
        save_initial_state=False,
    )

    parallelism = ParallelismArgs(
        dp=1,
        pp=1,
        tp=1,
        pp_engine="1f1b",
        tp_mode="REDUCE_SCATTER",
        tp_linear_async_communication=True,
    )

    tokens = TokensArgs(
        batch_accumulation_per_replica=1,
        micro_batch_size=8,
        sequence_length=2048,
        train_steps=TRAIN_STEPS,
        val_check_interval=-1,
    )

    model = ModelArgs(
        model_config=model_config,
        init_method=RandomInit(
            std=1 / math.sqrt(model_config.hidden_size),
        ),
        dtype=torch.bfloat16,
    )

    logging = LoggingArgs(
        # 'debug', 'info', 'warning', 'error', 'critical' and 'passive'
        log_level="info",
        log_level_replica="info",
        iteration_step_info_interval=1,
    )

    learning_rate_scheduler = LRSchedulerArgs(
        learning_rate=3e-3,
        lr_warmup_steps=10,
        lr_warmup_style="linear",
        lr_decay_style="linear",
        lr_decay_steps=20,
        lr_decay_starting_step=80,
        min_decay_lr=0,
    )

    optimizer = OptimizerArgs(
        zero_stage=0,
        weight_decay=0.01,
        clip_grad=1.0,
        accumulate_grad_in_fp32=True,
        learning_rate_scheduler=learning_rate_scheduler,
        optimizer_factory=AdamWOptimizerArgs(
            adam_eps=1e-08,
            adam_beta1=0.9,
            adam_beta2=0.95,
            torch_adam_is_fused=True,
        ),
    )

    tokenizer = TokenizerArgs(
        tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer",
    )

    # Uncomment if you want to upload the checkpoints to s3 or load a ckpt from s3
    # s3_upload = S3UploadArgs(
    #     upload_s3_path=f"S3_PATH",
    #     remove_after_upload=True,
    #     s5cmd_numworkers=16,
    #     s5cmd_concurrency=5,
    #     s5cmd_path="PATH_TO_S5CMD",
    # )

    data_stages = [
        DatasetStageArgs(
            data=DataArgs(
                # 1. Un-tokenized dataset from HuggingFace
                dataset=PretrainDatasetsArgs(
                    hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus",  # feel free to replace it with a smaller one if you don't have enough memory
                    hf_dataset_splits="train",
                    hf_dataset_config_name="cosmopedia-v2",
                    text_column_name="text",
                ),
                # 2. Pre-tokenized local dataset with Nanoset
                # dataset=NanosetDatasetsArgs(
                #     dataset_folder="datasets/cosmopedia-v2",
                # ),
                # num_loading_workers=0,
                # seed=general.seed,
            ),
            name="training stage",
            start_training_step=1,
        ),
        # You can add a decay stage here if you want to change the data mixture
        # Example (weights are arbitrary here):
        # DatasetStageArgs(
        #     data=DataArgs(
        #         dataset=NanosetDatasetsArgs(
        #             dataset_folder={
        #                 "datasets/fineweb-edu-dedup": 50,
        #                 "datasets/cosmopedia-v2": 30,
        #                 "datasets/python-edu": 10,
        #                 "datasets/open-web-math": 10,
        #             }
        #         ),
        #         num_loading_workers=0,
        #         seed=general.seed,
        #     ),
        #     name="decay stage",
        #     start_training_step=optimizer.learning_rate_scheduler.lr_decay_starting_step,
        # ),
    ]

    config = Config(
        general=general,
        checkpoints=checkpoints,
        parallelism=parallelism,
        model=model,
        tokenizer=tokenizer,
        logging=logging,
        tokens=tokens,
        optimizer=optimizer,
        data_stages=data_stages,
        lighteval=lighteval,
        # Review comment (Member): "not defined variable?"
    )

    save_path = Path(args.save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    config_path_yaml = save_path / f"{SAVE_NAME}.yaml"
    config.save_as_yaml(config_path_yaml)

    print(f"💾 Configuration saved in: {str(save_path)}")
    print("To launch this configuration, run:")
    print(f"python launcher.py --config-path configs/{str(config_path_yaml)}")