import json
- import logging
import os
- from pathlib import Path
- from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+ from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.interface import ModelWrapper
- from colossalai.utils import get_non_persistent_buffers_set
+ from colossalai.shardformer.layer.parallel_module import ParallelModule
+ from contextlib import contextmanager

- from .index_file import CheckpointIndexFile
from .utils import (
-     StateDictSharder,
-     async_save_state_dict_shards,
-     create_pinned_state_dict,
-     get_model_base_filenames,
    load_state_dict,
-     save_state_dict,
-     save_state_dict_shards,
    search_tp_partition_dim,
)

- try:
-     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
- except ImportError:
-     _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
-
MODEL_META_PREFIX = "pytorch_model-meta-dist-"
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
SHARD_META_SUFFIX = ".index.json"
+ UNSHARD_META_SUFFIX = ".json"


- def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
-     destination = dict()
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         destination[prefix + name] = param
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             buffer = buf if keep_vars else buf.detach()
-             destination[prefix + name] = buffer
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         destination[extra_state_key] = extra_state
-     return destination
-
-
- def load_state_dict_into_dist_model(
-     model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
- ):
-     destination = dict()
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         with torch.no_grad():
-             param.copy_(state_dict[prefix + name])
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             with torch.no_grad():
-                 buf.copy_(state_dict[prefix + name])
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         with torch.no_grad():
-             extra_state.copy_(state_dict[extra_state_key])
-     return destination
+ @contextmanager
+ def RestoreDefaultStateDictBehavior(model):
+     original_methods = {}
+     for name, module in model.named_modules():
+         if isinstance(module, ParallelModule):
+             original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
+             module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
+             module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
+     try:
+         yield model
+     finally:
+         for module, original_method in original_methods.items():
+             module._save_to_state_dict, module._load_from_state_dict = original_method
+


def create_model_metadata(
-     model: nn.Module,
+     model: ModelWrapper,
    prefix: str = "",
-     tp_size=None,
-     tp_rank=None,
+     tp_size: int = None,
+     tp_rank: int = None,
+     zero_size: int = None,
+     zero_rank: int = None,
):
    param_origin_shape = model.param_origin_shape
    model = model.unwrap()
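Review note: this hunk drops the hand-rolled `dist_model_state_dict` / `load_state_dict_into_dist_model` helpers in favor of the `RestoreDefaultStateDictBehavior` context manager, which temporarily rebinds `_save_to_state_dict` / `_load_from_state_dict` on every `ParallelModule` back to the stock `nn.Module` implementations. A minimal usage sketch follows; the `wrapped_model` variable and the call pattern are illustrative assumptions, not code from this PR:

```python
# Hedged sketch: dump and restore a rank's local shards by temporarily
# disabling ParallelModule's custom state_dict handling.
unwrapped = wrapped_model.unwrap()  # assumes a ModelWrapper, as elsewhere in this file

with RestoreDefaultStateDictBehavior(unwrapped):
    # Inside the block, state_dict() returns the local (still sharded) tensors
    # instead of tensors gathered along the tensor-parallel group.
    local_state = unwrapped.state_dict()

with RestoreDefaultStateDictBehavior(unwrapped):
    # Likewise, loading copies shards back without any TP resharding logic.
    unwrapped.load_state_dict(local_state, strict=False)
```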
@@ -105,7 +57,7 @@ def create_model_metadata(
        tp_partition_dim = search_tp_partition_dim(
            current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
        )
-         model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+         model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
        model_metadata[prefix + name]["lengths"] = list(param.shape)
        model_metadata[prefix + name]["global_shape"] = list(original_shape)
        if tp_partition_dim is not None:
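Review note: switching `offsets` from `torch.zeros(..., dtype=torch.int)` to a plain Python list avoids storing a `torch.Tensor` in metadata that is later serialized to JSON. Based only on the assignments visible in this hunk, one metadata entry for an unpartitioned 2-D parameter would look roughly like the sketch below (the parameter name and shape values are made up; the `tp_partition_dim is not None` branch lies outside this hunk):

```python
# Illustrative entry only; keys mirror the assignments in the hunk above.
model_metadata["decoder.layers.0.linear.weight"] = {
    "offsets": [0, 0],             # [0] * len(original_shape)
    "lengths": [1024, 4096],       # list(param.shape) held by this rank
    "global_shape": [1024, 4096],  # list(original_shape)
}
```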
@@ -257,119 +209,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
    return False


- def dist_model_sharder(
-     model: nn.Module,
-     prefix: str = "",
-     keep_vars: bool = False,
-     size_per_shard: int = 1024,
-     pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Iterator[Tuple[OrderedDict, int]]:
-     # An internel method that breaks state_dict of model into shards within limited size.
-
-     state_dict_sharder = StateDictSharder(size_per_shard)
-
-     # Save parameters.
-     for name, param in model.named_parameters():
-         if param is None:
-             continue
-         if pinned_state_dicts is not None:
-             if (prefix + name) not in pinned_state_dicts:
-                 pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
-             pinned_state_dicts[prefix + name].copy_(param)
-             param = pinned_state_dicts[prefix + name]
-         block, block_size = state_dict_sharder.append_param(prefix + name, param)
-         if block is not None:
-             yield block, block_size
-
-     # Save buffers.
-     non_persist_buffers_set = get_non_persistent_buffers_set(model)
-     for name, buf in model.named_buffers():
-         if buf is not None and name not in non_persist_buffers_set:
-             buffer = buf if keep_vars else buf.detach()
-             if pinned_state_dicts is not None:
-                 if (prefix + name) not in pinned_state_dicts:
-                     pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                 pinned_state_dicts[prefix + name].copy_(buffer)
-                 buffer = pinned_state_dicts[prefix + name]
-             block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
-             if block is not None:
-                 yield block, block_size
-
-     # Save extra states.
-     extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-     if (
-         getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-         is not torch.nn.Module.get_extra_state
-     ):
-         extra_state = model.get_extra_state()
-         if pinned_state_dicts is not None:
-             if extra_state_key not in pinned_state_dicts:
-                 pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-             pinned_state_dicts[extra_state_key].copy_(extra_state)
-             extra_state = pinned_state_dicts[extra_state_key]
-         block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
-         if block is not None:
-             yield block, block_size
-
-     # Return the last block in sharder.
-     yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
-
- def save_dist_unshard_model(
-     model: ModelWrapper,
-     model_metadata: Dict,
-     checkpoint: str,
-     use_safetensors: bool,
-     use_async: bool = False,
-     dist_id=0,
-     pinned_state_dicts=None,
- ):
-     """
-     Save model state dict to a single file with given checkpointing path.
-
-     Args:
-         model (nn.Module): Model on local device to be saved.
-         checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
-         gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
-         use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-         use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-     """
-
-     model = model.unwrap()
-
-     # The logic of collecting parameter shards along tp degree
-     # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-     state_dict = dist_model_state_dict(model)
-
-     Path(checkpoint).mkdir(parents=True, exist_ok=True)
-     file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
-     if use_async:
-         file_name = file_name.replace(".bin", ".safetensors")
-     checkpoint_file = os.path.join(checkpoint, file_name)
-     metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
-     save_metadata(model_metadata, metadata_file, file_name)
-
-     if use_async:
-         from colossalai.utils.safetensors import save
-
-         if id(model) not in pinned_state_dicts:
-             pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-         for name, param in state_dict.items():
-             pinned_state_dicts[id(model)][name].copy_(param)
-             state_dict[name] = pinned_state_dicts[id(model)][name]
-         writer = save(path=checkpoint_file, state_dict=state_dict)
-         return writer
-     else:
-         save_state_dict(state_dict, checkpoint_file, use_safetensors)
-         return None
-
-
def load_dist_model(
-     model: ModelWrapper,
    model_metadata: Dict,
    checkpoint: str,
-     low_cpu_mem_mode: bool = True,
-     num_threads: int = 1,
):
    """
    Load model from a single file with the given path of checkpoint.
@@ -380,10 +222,6 @@ def load_dist_model(
        strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                 This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
    """
-
-     model_before_wrapping = model
-     model = model.unwrap()
-
    metadata_loaded = load_metadata(checkpoint)

    load_files = {}
@@ -420,92 +258,14 @@ def load_dist_model(
            )
        state_dict[key] = state

-     if not low_cpu_mem_mode:
-         state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-
-     load_state_dict_into_dist_model(model=model, state_dict=state_dict)
-
-     # Update master params if mixed-precision training is enabled.
-     model_before_wrapping.update_master_params()
+     return state_dict

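Review note: `load_dist_model` no longer unwraps the model, copies tensors in place, or calls `update_master_params`; it now simply assembles and returns the state dict, leaving its application to the caller. A hedged sketch of one possible call site (the `booster_model` object and the use of `RestoreDefaultStateDictBehavior` here are assumptions about the caller, not code shown in this PR):

```python
# Hedged sketch of a caller consuming the new return value.
state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint_path)

unwrapped = booster_model.unwrap()
with RestoreDefaultStateDictBehavior(unwrapped):
    # Shards were written per rank, so they can be copied back without the
    # TP gather/split logic in ParallelModule._load_from_state_dict.
    unwrapped.load_state_dict(state_dict, strict=False)

# The removed code refreshed master params here; a caller presumably still
# needs to do this when mixed-precision training is enabled.
booster_model.update_master_params()
```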
-
- def save_dist_sharded_model(
-     model: ModelWrapper,
-     model_metadata: Dict,
-     checkpoint: str,
-     prefix: Optional[str] = None,
-     size_per_shard: int = 1024,
-     use_safetensors: bool = False,
-     use_async: bool = False,
-     dist_id: int = 0,
-     pinned_state_dicts=None,
- ) -> None:
-     """
-     Save sharded model checkpoint under the given checkpointing path.
-     The following files will be created under the path:
-     - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
-     - Multiple files that store state tensors of models.
-     If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
-     If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
-
-
-     Args:
-         model (nn.Module): Model on local device to be saved.
-         checkpoint (str): Checkpointing path which should be a directory path.
-         gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
-         prefix (str, optional): Perfix of file to save. Defaults to None.
-         size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
-         use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-         use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-     """
-
-     model = model.unwrap()
-
-     if os.path.isfile(checkpoint):
-         logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
-         return
-
-     Path(checkpoint).mkdir(parents=True, exist_ok=True)
-     # Devices along the same dp_group share the same copies of model.
-     # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.
-
-     if use_async:
-         if id(model) not in pinned_state_dicts:
-             pinned_state_dicts[id(model)] = {}
-         pinned_state_dicts = pinned_state_dicts[id(model)]
-     else:
-         pinned_state_dicts = None
-     state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
-     weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
-     index_file = CheckpointIndexFile(checkpoint)
-
-     # Manage filenames of sharded weights and index file for each pipeline stage.
+ def get_dist_files_name(weights_name, dist_id):
    weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
    weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
-     metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-     async_writers = []
-     if use_async:
-         total_size, writers = async_save_state_dict_shards(
-             sharded_state_dict=state_dict_shard,
-             checkpoint=checkpoint,
-             index_file=index_file,
-             base_filename=weights_name,
-             is_master=True,
-             state_preprocess=False,
-         )
-         async_writers.extend(writers)
-     else:
-         total_size = save_state_dict_shards(
-             sharded_state_dict=state_dict_shard,
-             checkpoint=checkpoint,
-             index_file=index_file,
-             base_filename=weights_name,
-             is_master=True,
-             use_safetensors=use_safetensors,
-             use_pp_format=True,
-         )
-     for k, _ in model_metadata.items():
-         model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+     return weights_name

-     save_metadata(model_metadata, metadata_file, total_size=total_size)
-     return async_writers
+ def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
+     if use_safetensors:
+         return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
+     return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
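Review note: a quick sanity check of the two new filename helpers, derived directly from the string operations above (the base weight names and directory below are illustrative inputs, not values produced elsewhere in this PR):

```python
# Expected results, following the .replace()/os.path.join logic above.
get_dist_files_name("pytorch_model.bin", dist_id=3)
# -> "pytorch_model-dist-00003-shard.bin"
get_dist_files_name("pytorch_model.safetensors", dist_id=3)
# -> "pytorch_model-dist-00003-shard.safetensors"

get_dist_meta_file_name("ckpt_dir", dist_id=3, use_safetensors=True)
# -> "ckpt_dir/pytorch_model-meta-dist-00003.index.json"   (SHARD_META_SUFFIX)
get_dist_meta_file_name("ckpt_dir", dist_id=3, use_safetensors=False)
# -> "ckpt_dir/pytorch_model-meta-dist-00003.json"          (UNSHARD_META_SUFFIX)
```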