# train.py
import argparse
import glob
import itertools
import json
import os
import shutil
import time
from contextlib import contextmanager
from datetime import datetime, timezone
import bitsandbytes
import deepspeed
import toml
import torch
import transformers
from deepspeed.runtime.pipe.module import LayerSpec
from hqq.core import quantize as hqq_quantize
from torch.utils.tensorboard import SummaryWriter
import dataloader
import engine
import hqq_utils
import models
import unsloth_utils
from dataset_utils import load_datasets
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_loraplus_optimizer
from saver import Saver
from utils import DTYPE_MAP, is_main_process
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='Path to TOML configuration file.')
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--debug_dataset', type=int, help='print out this many training examples and then quit')
parser.add_argument(
'--resume_from_checkpoint',
action='store_true',
default=None,
help='resume training from the most recent checkpoint',
)
parser.add_argument('--no_quantiles', action='store_true', help='suppress output of quantile metrics')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
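
# Debugging helper: on the main process, print the full module tree plus each
# parameter's dtype, device, and requires_grad flag.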
def print_model_info(model):
if not is_main_process():
return
print(model)
for name, module in model.named_modules():
print(f'{type(module)}: {name}')
for pname, p in module.named_parameters(recurse=False):
print(pname)
print(p.dtype)
print(p.device)
print(p.requires_grad)
print()
def set_config_defaults(config):
config['full_fine_tune'] = config.get('full_fine_tune', False)
config['load_in_4bit'] = config.get('load_in_4bit', False)
def get_most_recent_run_dir(output_dir):
return sorted(glob.glob(os.path.join(output_dir, '*')))[-1]
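
# Log a tuple of metric tensors to TensorBoard. The metrics arrive in a fixed positional
# order (mean loss, unreduced losses, hidden-state norms, entropy, normalised entropy,
# log-likelihood, McFadden's pseudo-R^2, top-1/5/20 accuracy, load-balancing losses);
# trailing entries may be absent depending on the model. Returns the mean loss.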
def write_metrics(tb_writer, prefix, metrics, step):
loss = metrics[0].mean().item()
tb_writer.add_scalar(f'{prefix}/loss', loss, step)
if len(metrics) > 1:
losses = metrics[1].view(-1)
positive_losses = losses > 0
tb_writer.add_histogram(f'{prefix}/log_loss_hist', torch.log(losses[positive_losses]), step)
if not args.no_quantiles:
sorted_losses, sorted_losses_idx = torch.sort(losses)
quantiles = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.999], dtype=torch.float32
).to(losses.device)
quantiles_idx = [int(len(losses) * quantile) for quantile in quantiles]
loss_quantiles = [sorted_losses[i] for i in quantiles_idx]
for quantile, value in zip(quantiles, loss_quantiles):
tb_writer.add_scalar(f'{prefix}/loss_quantile_{quantile:.3f}', value, step)
if len(metrics) > 2:
hidden_norm_avg = metrics[2].mean().item()
tb_writer.add_scalar(f'{prefix}/hidden_norm_avg', hidden_norm_avg, step)
hidden_state_norms = metrics[2].view(-1)
tb_writer.add_histogram(f'{prefix}/hidden_norm_hist', hidden_state_norms, step)
if len(metrics) > 3:
entropy = metrics[3].view(-1)
tb_writer.add_scalar(f'{prefix}/entropy', entropy.mean().item(), step)
if not args.no_quantiles:
assert entropy.size() == losses.size(), (entropy.size(), losses.size())
sorted_entropy = entropy[sorted_losses_idx]
entropy_quantiles = []
for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
entropy_quantiles.append(sorted_entropy[i:j].mean())
for quantile, value in zip(quantiles, entropy_quantiles):
tb_writer.add_scalar(f'{prefix}/entropy_quantile_{quantile:.3f}', value, step)
if len(metrics) > 4:
normalised_entropy = metrics[4].view(-1)
tb_writer.add_scalar(f'{prefix}/normalised_entropy', normalised_entropy.mean().item(), step)
if not args.no_quantiles:
assert normalised_entropy.size() == losses.size()
sorted_normalised_entropy = normalised_entropy[sorted_losses_idx]
normalised_entropy_quantiles = []
for i, j in itertools.zip_longest(quantiles_idx, quantiles_idx[1:]):
normalised_entropy_quantiles.append(sorted_normalised_entropy[i:j].mean())
for quantile, value in zip(quantiles, normalised_entropy_quantiles):
tb_writer.add_scalar(f'{prefix}/normalised_entropy_quantile_{quantile:.3f}', value, step)
if len(metrics) > 5:
log_likelihood = metrics[5].mean()
tb_writer.add_scalar(f'{prefix}/log_likelihood', log_likelihood.item(), step)
likelihood = torch.exp(-log_likelihood).item()
tb_writer.add_scalar(f'{prefix}/likelihood', likelihood, step)
perplexity = torch.exp(log_likelihood).item()
tb_writer.add_scalar(f'{prefix}/perplexity', perplexity, step)
if len(metrics) > 6:
mcfaddens_pseudo_r2 = metrics[6].mean()
tb_writer.add_scalar(f'{prefix}/mcfaddens_pseudo_r2', mcfaddens_pseudo_r2.item(), step)
if len(metrics) > 7:
tb_writer.add_scalar(f'{prefix}/top1_accuracy', metrics[7].mean().item(), step)
tb_writer.add_scalar(f'{prefix}/top5_accuracy', metrics[8].mean().item(), step)
tb_writer.add_scalar(f'{prefix}/top20_accuracy', metrics[9].mean().item(), step)
if len(metrics) > 10:
tb_writer.add_scalar(f'{prefix}/load_balancing_loss', metrics[10].mean().item(), step)
if len(metrics) > 11:
tb_writer.add_scalar(f'{prefix}/alternate_load_balancing_loss', metrics[11].mean().item(), step)
return loss
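
# Run one full pass over a single eval dataloader (the loader's epoch counter tells us
# when everything has been seen once), accumulate the metric tensors batch by batch,
# and write them to TensorBoard. Returns the mean eval loss on the main process, else None.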
def evaluate_single(model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps):
orig_micro_batches = model_engine.micro_batches
model_engine.micro_batches = eval_gradient_accumulation_steps
iterator = iter(eval_dataloader)
all_metrics = None
while True:
metrics = model_engine.eval_batch(iterator)
eval_dataloader.sync_epoch()
if all_metrics is None:
all_metrics = [[] for _ in range(len(metrics))]
if eval_dataloader.epoch == 2:
break
for i, metric in enumerate(metrics):
all_metrics[i].append(metric)
eval_dataloader.reset()
model_engine.micro_batches = orig_micro_batches
eval_metrics = [torch.cat(metric_list) for metric_list in all_metrics]
loss = None
if is_main_process():
loss = write_metrics(tb_writer, f'eval/{name}', eval_metrics, step)
return loss
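
# Evaluate on every eval dataloader, log per-dataset metrics plus total eval wall time,
# and return the mean of the per-dataset losses (main process only).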
def evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps):
if is_main_process():
print('Running eval')
start = time.time()
loss = []
for name, eval_dataloader in eval_dataloaders.items():
loss_or_none = evaluate_single(
model_engine, name, eval_dataloader, tb_writer, step, eval_gradient_accumulation_steps
)
if loss_or_none is not None:
loss.append(loss_or_none)
duration = time.time() - start
if is_main_process():
tb_writer.add_scalar('eval/eval_time_sec', duration, step)
return sum(loss) / len(loss) if len(loss) > 0 else None
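
# Kohya-style max-norm regularization for LoRA: reconstruct each effective delta-weight
# W = (B @ A) * (alpha / rank) and, if config['scale_weight_norms'] is set, rescale the
# A/B pair in place so that ||W|| does not exceed the configured maximum.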
def apply_max_norm_regularization(model, config):
    # modified from https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
A_keys = []
B_keys = []
norms = []
keys_scaled = 0
lora_scale = config['lora_alpha'] / config['lora_rank']
state_dict = model.state_dict()
for key in state_dict.keys():
if 'lora_A' in key:
A_keys.append(key)
B_keys.append(key.replace('lora_A', 'lora_B'))
for i in range(len(A_keys)):
A = state_dict[A_keys[i]]
B = state_dict[B_keys[i]]
W = B @ A
W *= lora_scale
if 'scale_weight_norms' in config:
max_norm = config['scale_weight_norms']
norm = W.norm().clamp(min=max_norm / 2)
desired = torch.clamp(norm, max=max_norm)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[A_keys[i]] *= sqrt_ratio
state_dict[B_keys[i]] *= sqrt_ratio
else:
ratio = 1.0
scalednorm = W.norm() * ratio
norms.append(scalednorm.item())
if len(norms) > 0:
norms = torch.tensor(norms, dtype=torch.float32)
if torch.any(torch.isnan(norms)):
raise RuntimeError('NaN detected in norms, probably some/all weights are NaN')
avg_norm = sum(norms) / len(norms)
max_norm = max(norms)
else:
avg_norm = 0
max_norm = 0
return keys_scaled, avg_norm, max_norm, norms
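
# Parse a layers_to_transform spec such as '0:7,16:23' into a flat list of layer indices;
# each 'start:stop' range is inclusive of both endpoints.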
def parse_layers_to_transform(spec):
parts = spec.split(',')
result = []
for part in parts:
start, stop = part.split(':')
result.extend(range(int(start), int(stop) + 1))
return result
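
# Context manager that serializes a block of code across the local ranks: each rank
# takes its turn (in LOCAL_RANK order) while the others wait at a barrier.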
@contextmanager
def one_at_a_time():
for i in range(int(os.environ['LOCAL_SIZE'])):
if i == int(os.environ['LOCAL_RANK']):
yield
deepspeed.comm.barrier()
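
# Build the pipeline-parallel model: construct the model-specific pipeline layers, apply
# optional bitsandbytes/HQQ quantization, wrap the layers in a CustomPipelineModule
# (optionally with activation checkpointing), and attach a PEFT LoRA adapter unless
# full_fine_tune is enabled. Returns (pipeline_model, lora_model, lora_config).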
def load_pipeline_model_with_lora(config, model_type, dynamic_shape=False):
full_fine_tune = config['full_fine_tune']
if config.get('quantization', None):
assert not full_fine_tune
no_quant_modules = ['lm_head']
if model_type == 'mixtral':
# the expert routing weights are tiny and probably important, don't quantize
no_quant_modules.append('gate')
if bnb_quant_config := config['quantization'].get('bnb', None):
if bnb_compute_dtype := bnb_quant_config.get('bnb_4bit_compute_dtype', None):
bnb_quant_config['bnb_4bit_compute_dtype'] = DTYPE_MAP[bnb_compute_dtype]
if 'bnb_4bit_quant_type' not in bnb_quant_config:
# Always want to default to nf4 if not specified.
bnb_quant_config['bnb_4bit_quant_type'] = 'nf4'
if llm_int8_skip_modules := bnb_quant_config.get('llm_int8_skip_modules', None):
no_quant_modules.extend(llm_int8_skip_modules)
no_quant_modules = list(set(no_quant_modules))
bnb_quant_config['llm_int8_skip_modules'] = no_quant_modules
quantization_config = transformers.BitsAndBytesConfig(**bnb_quant_config)
elif hqq_quant_config := config['quantization'].get('hqq', None):
quantization_config = hqq_utils.CustomHQQConfig(**hqq_quant_config)
# Use ATEN backend if possible, else PYTORCH. PYTORCH_COMPILE was only a tiny bit faster, and requires triton nightly.
hqq_quantize.HQQLinear.set_backend(
hqq_quantize.HQQBackend.ATEN if quantization_config.use_aten() else hqq_quantize.HQQBackend.PYTORCH
)
else:
raise NotImplementedError('Invalid quantization config')
if is_main_process():
print(f'Quantization config: {quantization_config}')
else:
quantization_config = None
if model_type == 'llama':
model = models.LlamaForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'mixtral':
model = models.MixtralForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'qwen2':
model = models.Qwen2ForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'cohere':
model = models.CohereForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'phi3':
model = models.Phi3ForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'gemma2':
model = models.Gemma2ForCausalLMPipe(config, quantization_config=quantization_config)
elif model_type == 'mistral':
model = models.MistralForCausalLMPipe(config, quantization_config=quantization_config)
else:
        raise NotImplementedError(f'Unsupported model_type: {model_type}')
# CAREFUL! The "primary" layers of the model have to have 'decoderlayer' in them for
# activation checkpointing to automatically work correctly.
layers = model.to_layer_specs()
checkpointable_layers = set()
for layer in layers:
if isinstance(layer, LayerSpec) and 'decoderlayer' in layer.typename.__name__.lower():
checkpointable_layers.add(layer.typename.__name__)
checkpointable_layers = list(checkpointable_layers)
partition_method = 'estimated_size'
if config['activation_checkpointing']:
# NOTE: must use a reentrant checkpointing function for MLP offloading to work.
if config['activation_checkpointing'] == 'unsloth':
checkpoint_func = unsloth_utils.unsloth_checkpoint
elif config['activation_checkpointing'] == 'cpu':
deepspeed.checkpointing.configure(None, checkpoint_in_cpu=True)
checkpoint_func = deepspeed.checkpointing.checkpoint
else:
checkpoint_func = deepspeed.checkpointing.checkpoint
pipeline_model = engine.CustomPipelineModule(
layers=layers,
num_stages=config['pipeline_stages'],
activation_checkpoint_interval=1,
checkpointable_layers=checkpointable_layers,
activation_checkpoint_func=checkpoint_func,
partition_method=partition_method,
use_column_major_topology=config.get('use_column_major_topology', False),
model=model,
dynamic_shape=dynamic_shape,
)
else:
pipeline_model = engine.CustomPipelineModule(
layers=layers,
num_stages=config['pipeline_stages'],
partition_method=partition_method,
use_column_major_topology=config.get('use_column_major_topology', False),
)
target_modules = config['target_modules'] if 'target_modules' in config else 'all-linear'
if full_fine_tune:
lora_model = None
lora_config = None
for name, p in model.named_parameters():
p.original_name = name
if isinstance(target_modules, list):
for name, p in model.named_parameters():
if not any(target in name for target in config['target_modules']):
p.requires_grad = False
print(f'not training {name} because it is not present in target_modules')
else:
layers_to_transform = (
parse_layers_to_transform(config['layers_to_transform']) if 'layers_to_transform' in config else None
)
lora_config = LoraConfig(
r=config['lora_rank'],
lora_alpha=config['lora_alpha'],
target_modules=target_modules,
modules_to_save=config['modules_to_save'] if 'modules_to_save' in config else [],
lora_dropout=config['lora_dropout'] if 'lora_dropout' in config else 0,
layers_to_transform=layers_to_transform,
bias='none',
task_type='CAUSAL_LM',
use_dora=config.get('use_dora', False),
)
lora_model = get_peft_model(model, lora_config)
# If the underlying weights are floats, the lora weights have already been
# cast to the same dtype, so we need to change the dtype here.
for p in lora_model.parameters():
if p.requires_grad:
p.data = p.data.to(DTYPE_MAP[config.get('lora_weight_dtype', 'float32')])
lora_model.model.config.use_cache = False
for name, p in lora_model.named_parameters():
p.original_name = name
return pipeline_model, lora_model, lora_config
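
# Script entry point: load the TOML/DeepSpeed configs, build the pipeline model and
# dataloaders, set up the optimizer and LR schedule, then run the training loop with
# periodic evaluation and checkpointing.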
if __name__ == '__main__':
# TODO: if resuming from checkpoint, probably should read all config files from checkpoint dir
# rather than assume they are unchanged on the command line
with open(args.config) as f:
config = toml.load(f)
set_config_defaults(config)
if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
# engine.initialize() will load deepspeed config from args
ds_config = None
else:
# The necessary ds_config fields are taken from the TOML config file.
ds_config = {
'train_micro_batch_size_per_gpu': config.get('micro_batch_size_per_gpu', 1),
'gradient_accumulation_steps': config.get('gradient_accumulation_steps', 1),
'gradient_clipping': config.get('gradient_clipping', 1.0),
'steps_per_print': config.get('steps_per_print', 1),
}
    resume_from_checkpoint = (
        args.resume_from_checkpoint
        if args.resume_from_checkpoint is not None
        else config.get('resume_from_checkpoint', False)
    )
deepspeed.init_distributed()
with open(os.path.join(config['model'], 'config.json')) as f:
model_config = json.load(f)
model_type = model_config.get('model_type', 'llama')
# Pad on left to support training techniques that involve sampling from the model.
tokenizer = transformers.AutoTokenizer.from_pretrained(
config['model'], local_files_only=True, model_max_length=int(1e30), padding_side='left'
)
# TODO: do we want to do this with cohere models? By default the EOS token is <|END_OF_TURN_TOKEN|>
# if model_type == 'cohere':
# tokenizer.eos_token = '<EOS_TOKEN>'
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
train_data, eval_data_map = load_datasets(config, tokenizer)
if args.debug_dataset:
if is_main_process():
for i, item in enumerate(iter(train_data)):
print('input_ids:')
print(item['input_ids'][:1000])
print('decoded input_ids:')
print(tokenizer.decode(item['input_ids'][:1000]))
print('attention_mask:')
print(item['attention_mask'][:1000])
print('labels:')
print(item['labels'][:1000])
if 'rejected_input_ids' in item:
                    print('rejected_input_ids:')
print(item['rejected_input_ids'][:1000])
print('decoded rejected_input_ids:')
print(tokenizer.decode(item['rejected_input_ids'][:1000]))
print('rejected_attention_mask:')
print(item['rejected_attention_mask'][:1000])
print('rejected_labels:')
print(item['rejected_labels'][:1000])
print('-' * 80)
if i >= args.debug_dataset - 1:
break
quit()
# for testing
# train_data = train_data.select(list(range(100)))
# if this is a new run, create a new dir for it
if not resume_from_checkpoint and is_main_process():
run_dir = os.path.join(config['output_dir'], datetime.now(timezone.utc).strftime('%Y%m%d_%H-%M-%S'))
os.makedirs(run_dir, exist_ok=True)
shutil.copy(args.config, run_dir)
if hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None:
shutil.copy(args.deepspeed_config, run_dir)
# wait for all processes then get the most recent dir (may have just been created)
deepspeed.comm.barrier()
run_dir = get_most_recent_run_dir(config['output_dir'])
# Ugly hack so we can move quantized models from GPU to CPU, and back to GPU again without triggering quantization a second time.
bnb_cuda_old = bitsandbytes.nn.modules.Params4bit.cuda
def bnb_cuda_hijack(self, device):
if getattr(self, 'already_quantized', False):
self.data = self.data.to(device)
self.quant_state.to(device)
return self
self.already_quantized = True
return bnb_cuda_old(self, device)
bitsandbytes.nn.modules.Params4bit.cuda = bnb_cuda_hijack
pipeline_model, lora_model, lora_config = load_pipeline_model_with_lora(config, model_type)
parameters_to_train = [p for p in pipeline_model.parameters() if p.requires_grad]
optim_config = config['optimizer']
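
    # Optimizer factory passed to engine.initialize(). Supports FusedAdam ('adamw'),
    # bitsandbytes AdamW8bit ('adamw8bit'), and optimi's Kahan-summation AdamW
    # ('adamw_kahan'), optionally wrapped by the LoRA+ optimizer builder.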
def get_optimizer(model_parameters):
lr = optim_config['lr']
optim_type = optim_config['type'].lower()
optimizer_kwargs = {
'params': model_parameters,
'lr': lr,
'betas': (optim_config.get('beta1', 0.9), optim_config.get('beta2', 0.99)),
'weight_decay': optim_config.get('weight_decay', 0.01),
'eps': optim_config.get('eps', 1e-6),
}
if optim_type == 'adamw':
optimizer_cls = deepspeed.ops.adam.FusedAdam
elif optim_type == 'adamw8bit':
optimizer_cls = bitsandbytes.optim.AdamW8bit
elif optim_type == 'adamw_kahan':
import optimi
optimizer_cls = optimi.AdamW
optimizer_kwargs['kahan_sum'] = optim_config.get('kahan_sum', True)
else:
raise NotImplementedError(optim_type)
if optim_config.get('use_loraplus', False):
loraplus_lr_ratio = optim_config.get('loraplus_lr_ratio', 16)
# TODO: handle params being thrown out here; why is it included in the first place?
# delete 'params' from optimizer_kwargs
del optimizer_kwargs['params']
return create_loraplus_optimizer(
model=pipeline_model,
optimizer_cls=optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
**optimizer_kwargs,
)
return optimizer_cls(**optimizer_kwargs)
model_engine, optimizer = engine.initialize(
args=args,
model=pipeline_model,
model_parameters=parameters_to_train,
optimizer=get_optimizer,
lora_model=lora_model,
config=ds_config,
)
if rl_config := config.get('rl', None):
model_engine.configure_rl(rl_config)
# TODO: I have recently realized that we are setting things to fp16/bf16, even though all the DS
# config was not in fp16 / bf16 mode. DS being in fp16/bf16 changes things in many places, e.g.
# it can give you a BF16_Optimizer wrapper that accumulates grads in fp32, the communication dtype
# is different, etc. I need to really look through all the implications of this. This change is so
# that we keep the normal optimizer, but the communication dtype is changed so that we don't
# unnecessarily cast grads to fp32.
weight_dtype = DTYPE_MAP[config.get('lora_weight_dtype', config.get('model_weight_dtype', 'float32'))]
model_engine.communication_data_type = weight_dtype
# TODO: the main DeepSpeedEngine forces all parameters to the GPU, and also does things like
# broadcast all parameters from data parallel rank 0 to all other ranks. Thus, MLP offloading
# must come after engine.initialize(). If we want to avoid loading everything onto GPUs only
# to offload the MLPs, we have to rewrite a lot of code to work around things.
if config.get('offload_mlp_to_cpu', False):
assert config['activation_checkpointing'] # MLP offloading only works with activation checkpointing
for module in pipeline_model.modules():
if hasattr(module, 'move_mlp_to_cpu'):
module.move_mlp_to_cpu()
torch.cuda.empty_cache()
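
    # Training dataloader: built with the engine's micro-batch size, gradient-accumulation
    # steps, and data-parallel rank/world size so each rank sees its own shard;
    # group_by_length and batch_size_tokens are optional config knobs.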
train_dataloader = dataloader.PipelineDataLoader(
train_data,
tokenizer,
model_engine.train_micro_batch_size_per_gpu(),
model_engine.gradient_accumulation_steps(),
model_engine.grid.get_data_parallel_world_size(),
model_engine.grid.get_data_parallel_rank(),
        group_by_length=config.get('group_by_length', False),
        batch_size_tokens=config.get('batch_size_tokens', None),
)
model_engine.set_dataloader(train_dataloader)
steps_per_epoch = len(train_dataloader) // model_engine.gradient_accumulation_steps()
model_engine.total_steps = steps_per_epoch * config['epochs']
if is_main_process():
# Warn if eval dataset is unusually large compared to the eval steps
eval_data_length = sum([len(eval_data) for eval_data in eval_data_map.values()])
train_data_length = len(train_data)
evals_per_epoch = steps_per_epoch / config['eval_steps']
relative_eval_time = evals_per_epoch * eval_data_length
# train step very roughly 3 times slower due to backprop + usually activation checkpointing is enabled
relative_train_time = train_data_length * 3
# Expect <=15% of our time spent evaluating vs training
fraction_evaling = relative_eval_time / (relative_eval_time + relative_train_time)
print()
print(
f'eval_data_length: {eval_data_length}, eval_steps: {config["eval_steps"]}; evals per epoch: {evals_per_epoch}. '
f'We will be spending approximately {fraction_evaling * 100:.2f}% of our time evaluating.'
)
if fraction_evaling > 0.15:
print(
'WARNING: eval dataset is unusually large compared to eval_steps. We will spend a lot of time evaluating. Lowering eval_size and/or bumping eval_steps is recommended.'
)
print()
# handle Deepspeed optimizer wrapper (e.g. BF16_Optimizer)
optimizer = getattr(optimizer, 'optimizer', optimizer)
warmup_steps = config.get('warmup_steps', 0)
# Fractional values less than 1 are converted into "fraction of epoch" worth of steps
if 0 < warmup_steps < 1:
warmup_steps = int(warmup_steps * steps_per_epoch)
    if config.get('lr_scheduler', 'constant') in ('constant', 'none'):
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
elif config['lr_scheduler'] == 'cosine':
total_steps = steps_per_epoch * config['epochs']
total_steps -= warmup_steps
lr_scheduler_kwargs = {
'optimizer': optimizer,
'T_max': total_steps,
}
if 'lr_min' in optim_config:
lr_scheduler_kwargs['eta_min'] = optim_config['lr_min']
# Normally, you would pass the lr_scheduler to deepspeed.initialize(). But we need the
# global batch_size in order to make the lr_scheduler.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(**lr_scheduler_kwargs)
else:
        raise NotImplementedError(f'Unsupported lr_scheduler: {config["lr_scheduler"]}')
load_optimizer_states = config.get('load_optimizer_states', True)
# if resuming and not loading optimizer states, we can't use warmup or the LR never changes from the initial value (still don't know why)
if warmup_steps > 0 and load_optimizer_states:
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=1 / warmup_steps, total_iters=warmup_steps
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
)
model_engine.lr_scheduler = lr_scheduler
step = 1
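
    # When resuming, load the DeepSpeed checkpoint from the run dir, restore the custom
    # dataloader state, and continue from the saved step + 1.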
if resume_from_checkpoint:
load_path, client_state = model_engine.load_checkpoint(
run_dir,
load_module_strict=False,
load_lr_scheduler_states='force_constant_lr' not in config,
load_optimizer_states=load_optimizer_states,
)
deepspeed.comm.barrier() # just so the print below doesn't get swamped
assert load_path is not None
train_dataloader.load_state_dict(client_state['custom_loader'])
step = client_state['step'] + 1
del client_state
# if we skip loading the optimizer states, we need to step the LR scheduler so we start at the right value
if not load_optimizer_states:
model_engine.lr_scheduler.step()
if is_main_process():
print(f'Resuming training from checkpoint. Resuming at epoch: {train_dataloader.epoch}, step: {step}')
if 'force_constant_lr' in config:
model_engine.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
for pg in optimizer.param_groups:
pg['lr'] = config['force_constant_lr']
# this is a separate option, because if it's too high we might drop a significant fraction of the eval dataset
eval_gradient_accumulation_steps = (
config['eval_gradient_accumulation_steps'] if 'eval_gradient_accumulation_steps' in config else 1
)
# Eval dataset doesn't need to repeat; we just use this to track "epoch" so we know when we're done iterating over it.
eval_dataloaders = {
name: dataloader.PipelineDataLoader(
eval_data,
tokenizer,
model_engine.train_micro_batch_size_per_gpu(),
eval_gradient_accumulation_steps,
model_engine.grid.get_data_parallel_world_size(),
model_engine.grid.get_data_parallel_rank(),
shuffle=False,
            group_by_length=config.get('group_by_length', False),
            batch_size_tokens=config.get('batch_size_tokens', None),
)
for name, eval_data in eval_data_map.items()
}
tb_writer = SummaryWriter(log_dir=run_dir) if is_main_process() else None
epoch = train_dataloader.epoch
saver = Saver(model_engine, pipeline_model, train_dataloader, lora_config, run_dir, args, config)
if config.get('eval_before_first_step', False) and not resume_from_checkpoint:
loss = evaluate(model_engine, eval_dataloaders, tb_writer, 0, eval_gradient_accumulation_steps)
saver.append_eval_results(loss, save_best=False)
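
    # Main training loop: one train_batch() per iteration (a full gradient-accumulation
    # cycle), with periodic TensorBoard logging, max-norm LoRA regularization, eval, and
    # checkpoint saving. The loop ends when the Saver reports the final epoch is done.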
while True:
metrics = model_engine.train_batch()
train_dataloader.sync_epoch()
if lora_config is not None:
keys_scaled, avg_norm, max_norm, norms = apply_max_norm_regularization(pipeline_model, config)
new_epoch = saver.process_epoch(epoch, step)
        finished_epoch = new_epoch != epoch
if is_main_process() and step % config['logging_steps'] == 0:
write_metrics(tb_writer, 'train', metrics, step)
tb_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
# TODO: gather the weight norms across all stages in the pipelined model, not just the first.
if lora_config is not None and len(norms) > 0:
tb_writer.add_scalar('train/weights_scaled', keys_scaled, step)
tb_writer.add_scalar('train/weight_norm_avg', avg_norm, step)
tb_writer.add_scalar('train/weight_norm_max', max_norm, step)
tb_writer.add_histogram('train/weight_norm_hist', norms, step)
tb_writer.add_scalar('train/epoch', step / steps_per_epoch, step)
if step % config['eval_steps'] == 0:
loss = evaluate(model_engine, eval_dataloaders, tb_writer, step, eval_gradient_accumulation_steps)
saver.append_eval_results(loss)
saver.process_step(step)
if finished_epoch:
epoch = new_epoch
if epoch is None:
break
step += 1
if ((step - 1) % config['eval_steps'] != 0) and config.get('eval_after_last_step', False):
loss = evaluate(model_engine, eval_dataloaders, tb_writer, step - 1, eval_gradient_accumulation_steps)
saver.append_eval_results(loss)
if is_main_process():
print('TRAINING COMPLETE!')