Skip to content

Commit

Permalink
last fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eliebak committed Sep 3, 2024
1 parent 157c2ae commit 6daa717
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 31 deletions.
30 changes: 10 additions & 20 deletions launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ def set_nested_attribute(obj, path, value):
args = parser.parse_args()

supported_base_configs = {
'llama-1B': "path_to_the_config",
}
"smollm-1700M-8nodes": "examples/smollm/configs/yaml/smollm-1700M-8nodes.yaml",
"smollm-360M-4nodes": "examples/smollm/configs/yaml/smollm-360M-4nodes.yaml",
"smollm-135M-4nodes": "examples/smollm/configs/yaml/smollm-135M-4nodes.yaml",
} # add your base configs here {name: path}

if args.base_config is None and args.config_path is None:
raise ValueError("Please provide a base config or a config path")
Expand All @@ -78,35 +80,23 @@ def set_nested_attribute(obj, path, value):
if config.general.logs_path is None and args.logs_path is None:
raise ValueError("Please provide a logs path")

if config.model.model_config.tie_word_embeddings ==True:
tie_word_embeddings_multiplier = 1
else:
tie_word_embeddings_multiplier = 2

num_params = human_format(
config.model.model_config.vocab_size * config.model.model_config.hidden_size * tie_word_embeddings_multiplier
+ config.model.model_config.num_hidden_layers
* (
3 * config.model.model_config.hidden_size * config.model.model_config.intermediate_size
+ 4 * config.model.model_config.hidden_size * config.model.model_config.hidden_size
)
).replace(".", "p")
# Apply overrides
config.model.model_config.get_llama_param_count()
).replace(".", ",")

if args.override:
for item in args.override:
if '=' not in item:
raise ValueError(f"Invalid override format: {item}. Use KEY=VALUE.")
key, value = item.split('=', 1)
try:
# Try to evaluate the value as a Python literal
value = eval(value)
except:
# If eval fails, treat it as a string
pass

set_nested_attribute(config, key, value)

print("Applied overrides:")
print("Applied overrides:")
for item in args.override:
print(f" {item}")

Expand All @@ -122,7 +112,7 @@ def set_nested_attribute(obj, path, value):
GBS = BS * config.parallelism.dp

total_tokens = config.tokens.train_steps * GBS
total_tokens_billions = total_tokens / 1e9
total_tokens_billions = human_format(total_tokens).replace(".", ",")

print(f"""
🏋️ Model Parameters:
Expand Down Expand Up @@ -153,7 +143,7 @@ def set_nested_attribute(obj, path, value):
print(f"""
📙 Training Configuration:
┌───────────────────────┬────────────────────────┐
│ Total Tokens │ {total_tokens_billions:>21.2f}B
│ Total Tokens │ {total_tokens_billions:>22}
│ Global Batch Size │ {GBS:>22,d}
│ Batch Size (per GPU) │ {BS:>22,d}
└───────────────────────┴────────────────────────┘
Expand Down
5 changes: 1 addition & 4 deletions slurm/eval_slurm_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
"account": null,
"reservation": null,
"torchrun_args": {
"rdzv_backend": "etcd-v2",
"rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379",
"rdzv_id": "$SLURM_JOB_ID",
"node_rank": "$SLURM_PROCID",
"role": "$SLURMD_NODENAME",
"max_restarts": 0,
"tee": 3
},
"hf_cache": "/fsx/elie_bakouch/.cache",
"hf_cache": "~/.cache",
"array": null,
"mem": null,
"begin": null
Expand Down
5 changes: 1 addition & 4 deletions slurm/launch_slurm_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
"account": null,
"reservation": null,
"torchrun_args": {
"rdzv_backend": "etcd-v2",
"rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379",
"rdzv_id": "$SLURM_JOB_ID",
"node_rank": "$SLURM_PROCID",
"role": "$SLURMD_NODENAME",
"max_restarts": 0,
"tee": 3
},
"hf_cache": "/fsx/elie_bakouch/.cache",
"hf_cache": "~/.cache",
"array": null,
"mem": null,
"begin": null
Expand Down
2 changes: 0 additions & 2 deletions slurm/launch_training.slurm.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ export HF_DATASETS_CACHE={{ hf_cache }}
export HF_MODULES_CACHE={{ hf_cache }}
export HF_HOME={{ hf_cache }}

module load cuda/12.1

echo go $COUNT_NODE
echo $HOSTNAMES

Expand Down
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class GeneralArgs:

project: str
run: Optional[str] = None
logs_path: Optional[str] = "./logs"
logs_path: Optional[str] = "logs"
launch_slurm_config: Optional[str] = None
eval_slurm_config: Optional[str] = None
timestamp_with_run: Optional[str] = None
Expand Down

0 comments on commit 6daa717

Please sign in to comment.