create_config is smollm-135M toy example
eliebak committed Sep 3, 2024
1 parent 6daa717 commit 17bfd5f
Showing 2 changed files with 86 additions and 55 deletions.
140 changes: 85 additions & 55 deletions create_config.py
@@ -32,12 +32,20 @@
    )

if __name__ == "__main__":
+    ###########################################
+    ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU)
+
+    HF_USER_OR_ORG = "eliebak"
+    TRAIN_STEPS = 100
+    CHECKPOINT_INTERVAL = 200
+    SAVE_NAME = "smollm-135M-1gpu-toy"
+
+    ###########################################
+
    parser = argparse.ArgumentParser()
-    parser.add_argument("--save-path", help="path to save the configuration file", type=str, required=True)
+    parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml")
    parser.add_argument("--seed", help="seed", type=int, default=8)
    parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="high")
    parser.add_argument("--override", nargs="+", metavar="KEY=VALUE",
                        help="Override config values. Use dot notation for nested keys.")
    parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately")
    parser.add_argument("--run", help="name of the run", type=str)
    parser.add_argument("--logs-path", help="path to the logs folder", type=str)
@@ -48,8 +56,9 @@
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    general = GeneralArgs(
        project="smollm",
+        run="toy-smollm",
        seed=args.seed,
-        temp_dir="/scratch",
+        temp_dir="temp",
    )

    model_config = LlamaConfig(
@@ -68,55 +77,56 @@
        rope_scaling=None,
        tie_word_embeddings=True,
        use_cache=True,
        vocab_size=49152,
    )

-    lighteval = LightEvalConfig(
-        tasks=LightEvalTasksArgs(
-            tasks="early-signal",  # "generatives", "all"
-            custom_tasks="nanotron.lighteval.evaluation_tasks",
-            max_samples=1000,  # Cap very large evals or for debugging
-            dataset_loading_processes=8,
-        ),
-        parallelism=ParallelismArgs(
-            dp=8,
-            pp=1,
-            tp=1,
-            pp_engine="1f1b",
-            tp_mode="ALL_REDUCE",
-            # recompute_granularity="selective",
-            tp_linear_async_communication=False,
-        ),
-        batch_size=16,
-        logging=LightEvalLoggingArgs(
-            local_output_path=f"/fsx/elie_bakouch/refactor-lighteval-logs/{general.project}-{general.run}",
-            # local_output_path=PATH_TO_LOCAL_LOG,
-            private=True,
-            push_details_to_hub=True,
-            push_results_to_hub=True,
-            push_results_to_tensorboard=True,
-            hf_user_or_org="eliebak",
-            # hf_user_or_org="USER_OR_ORG",
-            hub_repo_results="lighteval-results",
-            hub_repo_details="lighteval-details",
-            hub_repo_tensorboard="smollm-evals-visualization",
-            tensorboard_metric_prefix="eval",
-        ),
-        slurm_template="/fsx/elie_bakouch/nanotron/slurm/run_eval.slurm.jinja",
-    )

+    # lighteval = LightEvalConfig(
+    #     tasks=LightEvalTasksArgs(
+    #         tasks="early-signal",  # "generatives", "all"
+    #         custom_tasks="nanotron.lighteval.evaluation_tasks",
+    #         max_samples=1000,
+    #         dataset_loading_processes=8,
+    #     ),
+    #     parallelism=ParallelismArgs(
+    #         dp=8,
+    #         pp=1,
+    #         tp=1,
+    #         pp_engine="1f1b",
+    #         tp_mode="ALL_REDUCE",
+    #         # recompute_granularity="selective",
+    #         tp_linear_async_communication=False,
+    #     ),
+    #     batch_size=16,
+    #     logging=LightEvalLoggingArgs(
+    #         local_output_path="lighteval-logs",
+    #         private=True,
+    #         push_details_to_hub=True,
+    #         push_results_to_hub=True,
+    #         push_results_to_tensorboard=True,
+    #         hf_user_or_org=HF_USER_OR_ORG,
+    #         hub_repo_results="lighteval-results",
+    #         hub_repo_details="lighteval-details",
+    #         hub_repo_tensorboard="smollm-evals-visualization",
+    #         tensorboard_metric_prefix="eval",
+    #     ),
+    #     slurm_template="slurm/run_eval.slurm.jinja",
+    #     # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3
+    # )
+
+    lighteval = None

    checkpoints = CheckpointsArgs(
-        checkpoints_path=f"/fsx/elie_bakouch/refactor-checkpoints/{general.project}-{general.run}",
-        # checkpoints_path="CHECKPOINTS_PATH",
+        checkpoints_path="checkpoints",
        checkpoints_path_is_shared_file_system=False,
-        resume_checkpoint_path="/fsx/elie_bakouch/refactor-checkpoints/smollm-%date_%jobid/60",
-        checkpoint_interval=20,
+        # resume_checkpoint_path="",
+        checkpoint_interval=CHECKPOINT_INTERVAL,
        save_initial_state=False,
    )

    parallelism = ParallelismArgs(
-        dp=8,
+        dp=1,
        pp=1,
        tp=1,
        pp_engine="1f1b",
@@ -126,9 +136,9 @@

    tokens = TokensArgs(
        batch_accumulation_per_replica=8,
-        micro_batch_size=16,
+        micro_batch_size=8,
        sequence_length=2048,
-        train_steps=100,
+        train_steps=TRAIN_STEPS,
        val_check_interval=-1,
    )

@@ -148,12 +158,12 @@
    )

    learning_rate_scheduler = LRSchedulerArgs(
-        learning_rate=1e-4,
+        learning_rate=3e-3,
        lr_warmup_steps=10,
        lr_warmup_style="linear",
-        lr_decay_style="linear",
+        lr_decay_style="1-sqrt",
        lr_decay_steps=20,
-        lr_decay_starting_step=80,
+        lr_decay_starting_step=80,
        min_decay_lr=0,
    )

@@ -176,25 +186,45 @@
        tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer",
    )

+    # Uncomment if you want to upload the checkpoints to s3 or load a ckpt from s3
    # s3_upload = S3UploadArgs(
-    #     upload_s3_path=f"s3://elie-exp/debug_nanotron/{general.project}-{general.run}-{timestamp}",
+    #     upload_s3_path=f"S3_PATH",
    #     remove_after_upload=True,
    #     s5cmd_numworkers=16,
    #     s5cmd_concurrency=5,
-    #     s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd",
+    #     s5cmd_path="PATH_TO_S5CMD",
    # )

    data_stages=[
        DatasetStageArgs(
            data=DataArgs(
                dataset=NanosetDatasetsArgs(
-                    dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2",
+                    dataset_folder="datasets/cosmopedia-v2",
                ),
                num_loading_workers=0,
                seed=general.seed,
            ),
            name="training stage",
            start_training_step=1,
        ),
+        # You can add a decay stage here if you want to change the data mixture
+        # Example (weights are arbitrary here):
+        # DatasetStageArgs(
+        #     data=DataArgs(
+        #         dataset=NanosetDatasetsArgs(
+        #             dataset_folder={
+        #                 "datasets/fineweb-edu-dedup": 50,
+        #                 "datasets/cosmopedia-v2": 30,
+        #                 "datasets/python-edu": 10,
+        #                 "datasets/open-web-math": 10,
+        #             }
+        #         ),
+        #         num_loading_workers=0,
+        #         seed=general.seed,
+        #     ),
+        #     name="decay stage",
+        #     start_training_step=optimizer.learning_rate_scheduler.lr_decay_starting_step,
+        # ),
    ]

    config = Config(
@@ -213,13 +243,13 @@
    save_path = Path(args.save_path)
    save_path.mkdir(parents=True, exist_ok=True)

-    config_path_yaml = save_path / f"{args.run}-{timestamp}.yaml"
+    config_path_yaml = save_path / f"{SAVE_NAME}.yaml"
    config.save_as_yaml(config_path_yaml)

    print(f"💾 Configuration saved in: {str(save_path)}")

    if args.launch:
-        # Change the launcher_path
+
        # Sanity check for logs_path and run
        if not args.logs_path:
            raise ValueError("--logs_path must be defined. Please provide a path for the logs.")
@@ -248,4 +278,4 @@
        subprocess.run(launch_command, check=True)
    else:
        print("To launch this configuration, run:")
-        print(f"python {os.path.join(dir, 'launcher.py')} {config_path_yaml}")
+        print(f"python 'launcher.py' configs/{str(config_path_yaml)}")
1 change: 1 addition & 0 deletions launcher.py
@@ -58,6 +58,7 @@ def set_nested_attribute(obj, path, value):
"smollm-1700M-8nodes": "examples/smollm/configs/yaml/smollm-1700M-8nodes.yaml",
"smollm-360M-4nodes": "examples/smollm/configs/yaml/smollm-360M-4nodes.yaml",
"smollm-135M-4nodes": "examples/smollm/configs/yaml/smollm-135M-4nodes.yaml",
"smollm-135M-1gpu": "examples/smollm/configs/yaml/smollm-135M-1gpu.yaml",
} # add your base configs here {name: path}

if args.base_config is None and args.config_path is None:
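With the new entry in place, the toy config can presumably be selected by name instead of by path, e.g. python launcher.py --base-config smollm-135M-1gpu, assuming base_config is exposed as a CLI flag; only the args.base_config check is visible in this hunk.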
