Commit
revert create config mamba
3outeille committed Apr 24, 2024
1 parent e245eb8 commit f6befa6
Showing 1 changed file with 23 additions and 105 deletions.
128 changes: 23 additions & 105 deletions examples/mamba/create_config_mamba.py
@@ -1,11 +1,11 @@
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import math
import os
import pprint
import uuid

from config import MambaConfig, MambaInit, MambaModelConfig
from nanotron.config import (
    AdamWOptimizerArgs,
    CheckpointsArgs,
    DataArgs,
    DatasetStageArgs,
@@ -21,10 +21,9 @@
)
from nanotron.logging import human_format

import wandb

new_job_id = uuid.uuid4()
job_id = str(new_job_id)[:8]
seed = 42

ssm_cfg_dtype = "bfloat16"
ssm_cfg = {
@@ -44,7 +43,7 @@
# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json
model_config = MambaModelConfig(
    d_model=1024,
    num_hidden_layers=48,
    num_hidden_layers=2,
    vocab_size=50278,
    ssm_cfg=ssm_cfg,
    rms_norm=True,
@@ -95,27 +94,31 @@

seed = 42


optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,  # NOTE(fmom): because we are using PP=TP=DP=1
    adam_eps=1e-08,
    adam_beta1=0.9,
    adam_beta2=0.95,
    torch_adam_is_fused=True,
    learning_rate_scheduler=LRSchedulerArgs(
        learning_rate=0.0015,
        lr_warmup_steps=30,
        lr_warmup_style="linear",
        lr_decay_style="cosine",
        min_decay_lr=0.00015,
    ),
    optimizer_factory=AdamWOptimizerArgs(
        adam_eps=1e-08,
        adam_beta1=0.9,
        adam_beta2=0.95,
        torch_adam_is_fused=True,
    ),
)


parallelism = ParallelismArgs(
    dp=1,
    pp=1,
    dp=2,
    pp=2,
    tp=2,
    pp_engine="1f1b",
    tp_mode="ALL_REDUCE",
@@ -135,17 +138,19 @@
    )
]

model = ModelArgs(
    init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
    model_config=model_config,
)

checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

config = MambaConfig(
    general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True),
    checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100),
    parallelism=parallelism,
    model=ModelArgs(
        init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
        model_config=model_config,
    ),
    model=model,
    tokenizer=TokenizerArgs("gpt2"),
    optimizer=optimizer,
    logging=LoggingArgs(),
@@ -155,94 +160,7 @@
)

if __name__ == "__main__":
    import argparse
    from dataclasses import fields, is_dataclass

    from nanotron.config import get_config_from_file

    def print_differences(target, updates):
        if not is_dataclass(target) or not is_dataclass(updates):
            raise ValueError("Both target and updates should be dataclass instances")

        for field in fields(target):
            update_value = getattr(updates, field.name)

            if update_value is not None:
                if is_dataclass(update_value):
                    print_differences(getattr(target, field.name), update_value)
                else:
                    target_value = getattr(target, field.name)
                    if update_value != target_value:
                        if update_value.__class__.__module__ != "builtins":
                            continue
                        print(f"{field.name}: {target_value} -> {update_value}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", required=True, help="Output directory for yaml", type=str)
    parser.add_argument("--wandb-username", required=True, help="Specific wandb username", type=str)
    parser.add_argument("--wandb-project", required=True, help="Specific wandb project name", type=str)
    parser.add_argument("--wandb-run", required=True, help="Specific name for this run", type=str)

    args = parser.parse_args()

    config.general.project = args.wandb_project
    config.general.run = f"{args.wandb_run}_{job_id}"

    api = wandb.Api()
    projects = api.projects(entity=args.wandb_username)
    project_exists = any(project.name == args.wandb_project for project in projects)

    if not project_exists:
        raise ValueError(
            f"Project '{args.wandb_project}' does not exist. You should create the project first at entity {config.experiment_logger.wandb_logger.wandb_entity}"
        )

    directories = []

    experiment_path = f"{args.out_dir}/{config.general.project}/{config.general.run}"
    directories.append(experiment_path)

    config.checkpoints.checkpoints_path = f"{experiment_path}/checkpoints"
    config.checkpoints.resume_checkpoint_path = f"{experiment_path}/checkpoints"
    directories.append(config.checkpoints.checkpoints_path)
    directories.append(config.checkpoints.resume_checkpoint_path)

    # if config.lighteval is not None:
    #     config.lighteval.slurm_script_dir = f"{experiment_path}/lighteval/slurm_scripts"
    #     config.lighteval.slurm_template = f"{experiment_path}/run_eval.slurm.jinja"
    #     config.lighteval.logging.local_output_path = f"{experiment_path}/logs"

    #     directories.append(config.lighteval.slurm_script_dir)
    #     directories.append(config.lighteval.logging.local_output_path)

    # if config.s3_upload is not None:
    #     config.s3_upload.upload_s3_path = f"s3://huggingface-brrr-us-east-1/fmom/checkpoints/{args.wandb_run}_{job_id}"
    #     directories.append(config.s3_upload.upload_s3_path)

    # if config.profiler is not None:
    #     config.profiler.profiler_export_path = f"{experiment_path}/logs"

    directories.append(f"{experiment_path}/logs")

    for dir_path in directories:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)

    pprint.pprint(f"Dataset name: {config.data_stages}")
    print("Parallelism")
    print("\tdp", config.parallelism.dp)
    print("\tpp", config.parallelism.pp)
    print("\ttp", config.parallelism.tp)
    if config.lighteval is not None:
        print("Parallelism LightEval")
        print("\tdp", config.lighteval.parallelism.dp)
        print("\tpp", config.lighteval.parallelism.pp)
        print("\ttp", config.lighteval.parallelism.tp)

    yaml_path = f"{experiment_path}/{config.general.run}.yaml"
    # Sanity check that we can load, save to YAML and reload the config
    config.save_as_yaml(yaml_path)
    config2 = get_config_from_file(yaml_path, config_class=MambaConfig)
    print_differences(config, config2)

    print("Save at", yaml_path)
    dir = os.path.dirname(__file__)

    # Save config as YAML file
    config.save_as_yaml(f"{dir}/config_mamba.yaml")
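As a quick sanity check (a minimal sketch, not part of the commit: it reuses `get_config_from_file` and `MambaConfig` exactly as the deleted `__main__` block above did, and assumes you run it from the repository root after the script has written `examples/mamba/config_mamba.yaml`), the generated YAML can be loaded back and inspected:

# Minimal round-trip sketch: reload the YAML written by create_config_mamba.py
# into a MambaConfig and print the parallelism layout it encodes.
from config import MambaConfig  # examples/mamba/config.py
from nanotron.config import get_config_from_file

cfg = get_config_from_file("examples/mamba/config_mamba.yaml", config_class=MambaConfig)
print("dp:", cfg.parallelism.dp, "pp:", cfg.parallelism.pp, "tp:", cfg.parallelism.tp)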
