forked from meta-llama/llama-recipes
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_utils.py
598 lines (516 loc) · 27 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import time
import yaml
from contextlib import nullcontext
from pathlib import Path
from pkg_resources import packaging
import contextlib
import gc
from datetime import datetime
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tqdm import tqdm
from transformers import LlamaTokenizer
import json
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from llama_recipes.utils.tflop_counter import FlopCounterMode
@contextlib.contextmanager
def maybe_run_profiler(cfg, *args, **kwargs):
    """Context manager that optionally runs ``torch.profiler`` around a region.

    Args:
        cfg: config object; ``cfg.profiler`` enables profiling and
            ``cfg.profile_output_dir`` is the directory TensorBoard traces
            are written to.
        *args, **kwargs: accepted for call-site compatibility and ignored.

    Yields:
        The active ``torch.profiler.profile`` object when profiling is
        enabled, otherwise ``None``. The profiler uses a
        wait=1 / warmup=2 / active=3 schedule with one repeat, so the caller
        must call ``.step()`` on it once per iteration for a trace to be
        recorded.
    """
    if not cfg.profiler:
        yield None
        return
    print(f"profiling is activated and results will be saved in {cfg.profile_output_dir}")
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            cfg.profile_output_dir
        ),
        profile_memory=True,
        with_stack=False,
        record_shapes=True,
    ) as torch_profiler:
        yield torch_profiler
def get_total_flops(model):
    """Return the sum of all counts recorded under the ``"Global"`` scope.

    Args:
        model: a FLOP counter (``train`` passes a ``FlopCounterMode``) exposing
            a ``flop_counts`` mapping whose ``"Global"`` entry is a dict of
            counts; only the values are summed, the keys are ignored.

    Returns:
        The total count (0 when the ``"Global"`` dict is empty).
    """
    return sum(model.flop_counts["Global"].values())
from accelerate.utils import is_xpu_available, is_ccl_available
def set_tokenizer_params(tokenizer: LlamaTokenizer):
    """Configure ``tokenizer`` in place: pad on the left using token id 0."""
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = 0
# Bytes -> megabytes helper (binary megabytes: 1 MB = 2**20 bytes).
def byte2mb(x):
    """Return ``x`` bytes expressed in whole megabytes, truncated toward zero."""
    bytes_per_mb = 1 << 20
    return int(x / bytes_per_mb)
def _move_batch_to_device(batch, train_config, local_rank):
    """Move every tensor in ``batch`` (modified in place) to this process's device.

    Under FSDP the device is indexed by ``local_rank``; otherwise device 0 is
    used. XPU is chosen over CUDA whenever ``is_xpu_available()`` reports one.
    """
    use_xpu = is_xpu_available()
    if train_config.enable_fsdp:
        device = torch.device(f"xpu:{local_rank}") if use_xpu else local_rank
    else:
        device = "xpu:0" if use_xpu else "cuda:0"
    for key in batch.keys():
        batch[key] = batch[key].to(device)


def _print_memory_stats(memtrace):
    """Print the peak device and CPU memory figures collected by a ``MemoryTrace``."""
    if is_xpu_available():
        print(f"Max XPU memory allocated was {memtrace.peak} GB")
        print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
        print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
        print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
    else:
        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
    print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")


def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler,
          gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
    """
    Trains the model on the given dataloader, optionally validating and
    checkpointing after every epoch.

    Args:
        model: The model to be trained (FSDP-wrapped when ``train_config.enable_fsdp``).
        train_dataloader: The dataloader containing the training data; each batch
            is a mapping of tensors passed as ``model(**batch)``.
        eval_dataloader: The dataloader containing the eval data.
        tokenizer: Passed through to ``evaluation``.
        optimizer: The optimizer used for training.
        lr_scheduler: The learning rate scheduler, stepped once per epoch.
        gradient_accumulation_steps: Number of micro-batches whose gradients are
            accumulated before each optimizer step.
        train_config: The training configuration (num_epochs, use_fp16, enable_fsdp,
            gradient clipping, profiler/flop counter, saving and metrics options).
        fsdp_config: FSDP configuration; its ``checkpoint_type`` selects the
            checkpoint format when a non-PEFT model is saved.
        local_rank: Device index of this process under FSDP; also part of the
            metrics filename.
        rank: Global rank; under FSDP only rank 0 prints progress messages.

    Returns:
        dict with ``avg_train_prep``, ``avg_train_loss``, ``avg_epoch_time`` and
        ``avg_checkpoint_time``; plus ``avg_eval_prep``/``avg_eval_loss`` when
        ``run_validation``, ``model_flops`` (TFLOPs of one measured step, or
        ``None`` if no step was measured) when ``flop_counter``, and
        ``metrics_filename`` when ``save_metrics``.
    """
    # Create a gradient scaler for fp16
    if train_config.use_fp16 and train_config.enable_fsdp:
        scaler = ShardedGradScaler()
    elif train_config.use_fp16 and not train_config.enable_fsdp:
        scaler = torch.cuda.amp.GradScaler()
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
    # Under FSDP only rank 0 prints; a single-process run always prints.
    is_main_process = not train_config.enable_fsdp or rank == 0
    clip_gradients = train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0

    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []
    if train_config.save_metrics:
        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
        train_step_perplexity = []
        train_step_loss = []
        val_step_loss = []
        val_step_perplexity = []
    epoch_times = []
    checkpoint_times = []
    results = {}
    best_val_loss = float("inf")
    # TFLOPs of the single FLOP-counted step; stays None if no step is measured.
    tflops = None
    flop_check_done = False
    num_batches = len(train_dataloader)

    for epoch in range(train_config.num_epochs):
        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
            total_length = num_batches // gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
            # A single pass over the data per epoch; profiling and FLOP counting
            # piggy-back on the real training steps.
            with maybe_run_profiler(train_config) as torch_profiler:
                for step, batch in enumerate(train_dataloader):
                    gc.collect(1)
                    _move_batch_to_device(batch, train_config, local_rank)

                    # FLOP-count one full forward+backward pass, once, at step 3.
                    measure_flops = train_config.flop_counter and step == 3 and not flop_check_done
                    flop_counter = FlopCounterMode(rank=local_rank) if measure_flops else nullcontext()
                    with flop_counter:
                        with autocast():
                            raw_loss = model(**batch).loss
                        # Divide so gradients accumulated over the window average rather than sum.
                        loss = raw_loss / gradient_accumulation_steps
                        if train_config.use_fp16:
                            # if fp16 is enabled, use gradient scaler to handle gradient update
                            scaler.scale(loss).backward()
                        else:
                            # regular backpropagation when fp16 is not used
                            loss.backward()
                    if measure_flops:
                        tflops = get_total_flops(flop_counter) / 1e12
                        flop_check_done = True

                    # Report the unscaled loss so epoch loss/perplexity do not
                    # depend on the accumulation setting.
                    step_loss = raw_loss.detach().float()
                    total_loss += step_loss
                    if train_config.save_metrics:
                        train_step_loss.append(step_loss.item())
                        train_step_perplexity.append(float(torch.exp(step_loss)))

                    if (step + 1) % gradient_accumulation_steps == 0 or step == num_batches - 1:
                        if clip_gradients:
                            if train_config.use_fp16:
                                # Unscale first so the threshold applies to the true gradient norm.
                                scaler.unscale_(optimizer)
                            if train_config.enable_fsdp:
                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
                            else:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                        if train_config.use_fp16:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                        optimizer.zero_grad()
                        pbar.update(1)

                    if torch_profiler is not None:
                        # Advance the profiler schedule (wait/warmup/active) once per step.
                        torch_profiler.step()

                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{num_batches} completed (loss: {step_loss})")
                    if train_config.save_metrics:
                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
            pbar.close()

        epoch_end_time = time.perf_counter() - epoch_start_time
        epoch_times.append(epoch_end_time)
        # Reducing total_loss across all devices if there's more than one device
        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        train_epoch_loss = total_loss / num_batches
        if train_config.enable_fsdp:
            train_epoch_loss = train_epoch_loss / world_size
        train_perplexity = torch.exp(train_epoch_loss)
        train_prep.append(float(train_perplexity))
        train_loss.append(float(train_epoch_loss))

        if is_main_process:
            _print_memory_stats(memtrace)

        # Update the learning rate as needed
        lr_scheduler.step()

        if train_config.run_validation:
            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
            if train_config.save_metrics:
                val_step_loss.extend(temp_val_loss)
                val_step_perplexity.extend(temp_step_perplexity)

            checkpoint_start_time = time.perf_counter()
            if train_config.save_model and eval_epoch_loss < best_val_loss:
                if train_config.enable_fsdp:
                    dist.barrier()
                if train_config.use_peft:
                    if is_main_process:
                        print(f"we are about to save the PEFT modules")
                    model.save_pretrained(train_config.output_dir)
                    if is_main_process:
                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
                else:
                    # NOTE(review): fsdp_config is dereferenced even when FSDP is
                    # disabled; callers must pass one whenever a full model is saved.
                    if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                        save_model_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                        print("=====================================================")
                        save_model_and_optimizer_sharded(model, rank, train_config)
                        if train_config.save_optimizer:
                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                            print("=====================================================")
                    # NOTE(review): this also runs for SHARDED_STATE_DICT checkpoints;
                    # confirm the extra full optimizer save is intended.
                    if train_config.save_optimizer:
                        save_optimizer_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                        print("=====================================================")
                if train_config.enable_fsdp:
                    dist.barrier()
            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
            checkpoint_times.append(checkpoint_end_time)

            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                if is_main_process:
                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
            # Record this epoch's loss so it pairs with this epoch's perplexity.
            val_loss.append(float(eval_epoch_loss))
            val_prep.append(float(eval_ppl))

        if is_main_process:
            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")

        # Saving the results every epoch to plot later
        if train_config.save_metrics:
            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    if train_config.run_validation:
        avg_eval_prep = sum(val_prep) / len(val_prep)
        avg_eval_loss = sum(val_loss) / len(val_loss)

    results['avg_train_prep'] = avg_train_prep
    results['avg_train_loss'] = avg_train_loss
    if train_config.run_validation:
        results['avg_eval_prep'] = avg_eval_prep
        results['avg_eval_loss'] = avg_eval_loss
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time
    if train_config.flop_counter:
        results["model_flops"] = tflops
    if train_config.save_metrics:
        results["metrics_filename"] = metrics_filename

    # saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and not train_config.use_peft:
        save_train_params(train_config, fsdp_config, rank)

    return results
def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluates the model on the given dataloader.

    Args:
        model: The model to evaluate; ``model(**batch)`` must return an object
            with a ``.loss`` attribute.
        train_config: The training configuration (reads ``enable_fsdp`` and
            ``save_metrics``).
        eval_dataloader: The dataloader containing the evaluation data.
        local_rank: Device index of this process under FSDP; local rank 0
            prints the summary.
        tokenizer: Unused; kept so existing callers keep working.

    Returns:
        (eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity): the
        perplexity and mean loss averaged over batches (and over ranks under
        FSDP), plus per-step loss/perplexity lists that are only filled when
        ``train_config.save_metrics`` is set.
    """
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    use_xpu = is_xpu_available()
    # Put batches on the same device the training loop uses: the local_rank
    # device under FSDP (XPU when present), device 0 otherwise.
    if train_config.enable_fsdp:
        device = torch.device(f"xpu:{local_rank}") if use_xpu else local_rank
    else:
        device = "xpu:0" if use_xpu else "cuda:0"
    model.eval()
    val_step_loss = []
    val_step_perplexity = []
    eval_loss = 0.0  # Initialize evaluation loss
    with MemoryTrace():
        for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True):
            gc.collect(1)
            for key in batch.keys():
                batch[key] = batch[key].to(device)
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                # Forward pass and compute loss
                outputs = model(**batch)
                loss = outputs.loss
                if train_config.save_metrics:
                    val_step_loss.append(loss.detach().float().item())
                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
                eval_loss += loss.detach().float()

    # If there's more than one device, reduce evaluation loss across all
    # devices exactly once (elif, matching train()).
    if use_xpu and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)

    # Compute average loss and perplexity
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    if train_config.enable_fsdp:
        eval_epoch_loss = eval_epoch_loss / world_size
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    if not train_config.enable_fsdp or local_rank == 0:
        print(f" {eval_ppl=} {eval_epoch_loss=}")

    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
def freeze_transformer_layers(model, num_layer):
    """Disable gradients for every parameter of the layers in
    ``model.model.layers`` whose index is below ``num_layer``.

    Later layers are left untouched; a non-positive ``num_layer`` freezes nothing.
    """
    for layer_idx, block in enumerate(model.model.layers):
        if layer_idx >= num_layer:
            continue
        for weight in block.parameters():
            weight.requires_grad = False
def check_frozen_layers_peft_model(model):
    """Print the ``requires_grad`` flag of every parameter of every decoder
    layer reached through ``model.base_model.model.model.layers``."""
    decoder_layers = model.base_model.model.model.layers
    for layer_idx, block in enumerate(decoder_layers):
        for param_name, tensor in block.named_parameters():
            trainable = tensor.requires_grad
            print(f"Layer {layer_idx}, parameter {param_name}: requires_grad = {trainable}")
def setup():
    """Initialize the default process group for distributed training.

    Uses the "ccl" backend (for XPUs) when ``is_ccl_available()`` is true and
    "nccl" otherwise.
    """
    backend = "ccl" if is_ccl_available() else "nccl"
    dist.init_process_group(backend)
def setup_environ_flags(rank):
    """Set debugging-related environment flags for this process.

    Sets TORCH_SHOW_CPP_STACKTRACES and NCCL_ASYNC_ERROR_HANDLING to "1";
    only ``rank`` 0 prints a notice.
    """
    debug_flags = {
        "TORCH_SHOW_CPP_STACKTRACES": str(1),
        "NCCL_ASYNC_ERROR_HANDLING": str(1),
    }
    for flag_name, flag_value in debug_flags.items():
        os.environ[flag_name] = flag_value
    # Optional extras, intentionally left disabled:
    #   TORCH_DISTRIBUTED_DEBUG="DETAIL" for verbose distributed debugging;
    #   PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" to reduce CUDA memory
    #   fragmentation that can lead to OOM (noted as nightly-only as of July 30 2023).
    # NOTE(review): the message below mentions "detail" although
    # TORCH_DISTRIBUTED_DEBUG is not set here.
    if rank == 0:
        print(f"--> Running with torch dist debug set to detail")
def cleanup():
    """Destroy the default process group (the one ``setup()`` initializes) once training is done."""
    dist.destroy_process_group()
def clear_gpu_cache(rank=None):
    """Release cached allocator memory on this process's accelerator.

    Each calling process empties its own device cache (XPU when
    ``is_xpu_available()`` reports one, CUDA otherwise); only ``rank`` 0 prints
    a notice.
    """
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    if is_xpu_available():
        # The API lives on the torch.xpu namespace; ``torch.xpu_empty_cache``
        # does not exist and raised AttributeError on XPU machines.
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()
def get_parameter_dtypes(model):
    """Return a dict mapping each name from ``model.named_parameters()`` to that parameter's dtype."""
    return {name: tensor.dtype for name, tensor in model.named_parameters()}
def print_model_size(model, config, rank: int = 0) -> None:
    """
    Print the model name and its number of trainable parameters, in millions.

    Only parameters with ``requires_grad`` set are counted. Nothing is printed
    unless ``rank`` is 0.

    Args:
        model: The PyTorch model.
        config: Configuration object; only ``config.model_name`` is read.
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank != 0:
        return
    print(f"--> Model {config.model_name}")
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n--> {config.model_name} has {trainable_params / 1e6} Million params\n")
def get_policies(cfg, rank):
    """Get the policies for mixed precision and FSDP wrapping.

    Args:
        cfg: config read for ``mixed_precision`` and ``use_fp16``.
        rank: global rank; only rank 0 prints which precision was chosen.

    Returns:
        (mixed_precision_policy, wrapping_policy): ``fpSixteen`` when fp16 is
        requested, ``bfSixteen`` when bf16 is supported, otherwise ``None``
        (plain FP32); the wrapping policy comes from ``get_llama_wrapper()``.
    """
    mixed_precision_policy = None
    if cfg.mixed_precision:
        if cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                print(f"FP16 enabled")
        else:
            # bf16 needs CUDA >= 11 with a bf16-capable GPU and NCCL >= 2.10, or an XPU.
            # Probed only when it matters for the decision.
            bf16_ready = (
                torch.version.cuda
                and torch.cuda.is_bf16_supported()
                and packaging.version.parse(torch.version.cuda).release >= (11, 0)
                and dist.is_nccl_available()
                and nccl.version() >= (2, 10)
            ) or is_xpu_available()
            if bf16_ready:
                mixed_precision_policy = bfSixteen
                if rank == 0:
                    print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
            elif rank == 0:
                # Rank-gated like the other messages so it is not repeated on every process.
                print(f"bFloat16 support not present. Using FP32, and not mixed precision")
    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy
def save_train_params(train_config, fsdp_config, rank):
    """
    Save the train_config and FSDP config into a train_params.yaml.

    The inference converter script reads this file to fetch the HF model name
    or path, and it also serves as a log for future reference.

    The file is written under
    ``<cwd>/<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>``
    by every calling process; only ``rank`` 0 prints a confirmation.
    """
    # Convert the train_config and fsdp_config objects to dictionaries,
    # converting all values to strings to ensure they can be serialized into a YAML file
    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
    # Merge the two dictionaries into one (fsdp_config wins on key collisions)
    train_params_dict = {**train_config_dict, **fsdp_config_dict}
    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
    folder_name = (
        train_config.dist_checkpoint_root_folder
        + "/"
        + train_config.dist_checkpoint_folder
        + "-"
        + train_config.model_name
    )
    save_dir = Path.cwd() / folder_name
    # exist_ok avoids a FileExistsError race when several ranks reach this
    # point at once (a separate exists() check cannot prevent it).
    os.makedirs(save_dir, exist_ok=True)
    # Convert the dictionary to a YAML string
    config_yaml = yaml.dump(train_params_dict, indent=4)
    file_name = os.path.join(save_dir, 'train_params.yaml')
    # Check if there's a directory with the same name as the file
    if os.path.isdir(file_name):
        print(f"Error: {file_name} is a directory, not a file.")
    else:
        # Write the YAML string to the file
        with open(file_name, 'w') as f:
            f.write(config_yaml)
        if rank == 0:
            print(f"training params are saved in {file_name}")
def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
    """Write the collected loss/perplexity series to ``output_filename`` as one JSON object.

    The file is overwritten on every call.
    """
    field_names = (
        "train_step_loss",
        "train_epoch_loss",
        "train_step_perplexity",
        "train_epoch_perplexity",
        "val_step_loss",
        "val_epoch_loss",
        "val_step_perplexity",
        "val_epoch_perplexity",
    )
    series = (
        train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl,
        val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl,
    )
    with open(output_filename, "w") as out_file:
        json.dump(dict(zip(field_names, series)), out_file)