52
52
)
53
53
from torchrec .modules .embedding_configs import (
54
54
BaseEmbeddingConfig ,
55
+ DataType ,
55
56
EmbeddingBagConfig ,
56
57
PoolingType ,
57
58
)
@@ -240,7 +241,7 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
240
241
self .assertTrue (isinstance (model .module , ShardedFusedEmbeddingBagCollection ))
241
242
242
243
243
- class ModelParallelStateDictBase (unittest .TestCase ):
244
+ class ModelParallelSingleRankBase (unittest .TestCase ):
244
245
def setUp (self , backend : str = "nccl" ) -> None :
245
246
os .environ ["RANK" ] = "0"
246
247
os .environ ["WORLD_SIZE" ] = "1"
@@ -261,9 +262,79 @@ def setUp(self, backend: str = "nccl") -> None:
261
262
262
263
dist .init_process_group (backend = backend )
263
264
265
+ def tearDown (self ) -> None :
266
+ dist .destroy_process_group ()
267
+ del os .environ ["NCCL_SOCKET_IFNAME" ]
268
+ super ().tearDown ()
269
+
270
+ def _train_models (
271
+ self ,
272
+ m1 : DistributedModelParallel ,
273
+ m2 : DistributedModelParallel ,
274
+ batch : ModelInput ,
275
+ ) -> None :
276
+ loss1 , pred1 = m1 (batch )
277
+ loss2 , pred2 = m2 (batch )
278
+ loss1 .backward ()
279
+ loss2 .backward ()
280
+
281
+ def _eval_models (
282
+ self ,
283
+ m1 : DistributedModelParallel ,
284
+ m2 : DistributedModelParallel ,
285
+ batch : ModelInput ,
286
+ is_deterministic : bool = True ,
287
+ ) -> None :
288
+ with torch .no_grad ():
289
+ loss1 , pred1 = m1 (batch )
290
+ loss2 , pred2 = m2 (batch )
291
+
292
+ if is_deterministic :
293
+ self .assertTrue (torch .equal (loss1 , loss2 ))
294
+ self .assertTrue (torch .equal (pred1 , pred2 ))
295
+ else :
296
+ rtol , atol = _get_default_rtol_and_atol (loss1 , loss2 )
297
+ torch .testing .assert_close (loss1 , loss2 , rtol = rtol , atol = atol )
298
+ rtol , atol = _get_default_rtol_and_atol (pred1 , pred2 )
299
+ torch .testing .assert_close (pred1 , pred2 , rtol = rtol , atol = atol )
300
+
301
+ def _compare_models (
302
+ self ,
303
+ m1 : DistributedModelParallel ,
304
+ m2 : DistributedModelParallel ,
305
+ is_deterministic : bool = True ,
306
+ ) -> None :
307
+ sd1 = m1 .state_dict ()
308
+ for key , value in m2 .state_dict ().items ():
309
+ v2 = sd1 [key ]
310
+ if isinstance (value , ShardedTensor ):
311
+ assert isinstance (v2 , ShardedTensor )
312
+ self .assertEqual (len (value .local_shards ()), len (v2 .local_shards ()))
313
+ for dst , src in zip (value .local_shards (), v2 .local_shards ()):
314
+ if is_deterministic :
315
+ self .assertTrue (torch .equal (src .tensor , dst .tensor ))
316
+ else :
317
+ rtol , atol = _get_default_rtol_and_atol (src .tensor , dst .tensor )
318
+ torch .testing .assert_close (
319
+ src .tensor , dst .tensor , rtol = rtol , atol = atol
320
+ )
321
+ else :
322
+ dst = value
323
+ src = v2
324
+ if is_deterministic :
325
+ self .assertTrue (torch .equal (src , dst ))
326
+ else :
327
+ rtol , atol = _get_default_rtol_and_atol (src , dst )
328
+ torch .testing .assert_close (src , dst , rtol = rtol , atol = atol )
329
+
330
+
331
+ class ModelParallelStateDictBase (ModelParallelSingleRankBase ):
332
+ def setUp (self , backend : str = "nccl" ) -> None :
333
+ super ().setUp (backend = backend )
334
+
264
335
num_features = 4
265
336
num_weighted_features = 2
266
- self .batch_size = 3
337
+ self .batch_size = 20
267
338
self .num_float_features = 10
268
339
269
340
self .tables = [
@@ -285,11 +356,6 @@ def setUp(self, backend: str = "nccl") -> None:
285
356
for i in range (num_weighted_features )
286
357
]
287
358
288
- def tearDown (self ) -> None :
289
- dist .destroy_process_group ()
290
- del os .environ ["NCCL_SOCKET_IFNAME" ]
291
- super ().tearDown ()
292
-
293
359
def _generate_dmps_and_batch (
294
360
self ,
295
361
sharders : Optional [List [ModuleSharder [nn .Module ]]] = None ,
@@ -352,6 +418,13 @@ def _generate_dmps_and_batch(
352
418
dmps .append (dmp )
353
419
return (dmps , batch )
354
420
421
+ def _set_table_weights_precision (self , dtype : DataType ) -> None :
422
+ for table in self .tables :
423
+ table .data_type = dtype
424
+
425
+ for weighted_table in self .weighted_tables :
426
+ weighted_table .data_type = dtype
427
+
355
428
def test_parameter_init (self ) -> None :
356
429
class MyModel (nn .Module ):
357
430
def __init__ (self , device : str , val : float ) -> None :
@@ -513,33 +586,9 @@ def test_load_state_dict(
513
586
514
587
# validate the models are equivalent
515
588
if is_training :
516
- for _ in range (2 ):
517
- loss1 , pred1 = m1 (batch )
518
- loss2 , pred2 = m2 (batch )
519
- loss1 .backward ()
520
- loss2 .backward ()
521
- self .assertTrue (torch .equal (loss1 , loss2 ))
522
- self .assertTrue (torch .equal (pred1 , pred2 ))
523
- else :
524
- with torch .no_grad ():
525
- loss1 , pred1 = m1 (batch )
526
- loss2 , pred2 = m2 (batch )
527
- self .assertTrue (torch .equal (loss1 , loss2 ))
528
- self .assertTrue (torch .equal (pred1 , pred2 ))
529
- sd1 = m1 .state_dict ()
530
- for key , value in m2 .state_dict ().items ():
531
- v2 = sd1 [key ]
532
- if isinstance (value , ShardedTensor ):
533
- assert len (value .local_shards ()) == 1
534
- dst = value .local_shards ()[0 ].tensor
535
- else :
536
- dst = value
537
- if isinstance (v2 , ShardedTensor ):
538
- assert len (v2 .local_shards ()) == 1
539
- src = v2 .local_shards ()[0 ].tensor
540
- else :
541
- src = v2
542
- self .assertTrue (torch .equal (src , dst ))
589
+ self ._train_models (m1 , m2 , batch )
590
+ self ._eval_models (m1 , m2 , batch )
591
+ self ._compare_models (m1 , m2 )
543
592
544
593
# pyre-ignore[56]
545
594
@given (
@@ -582,33 +631,9 @@ def test_load_state_dict_dp(
582
631
583
632
# validate the models are equivalent
584
633
if is_training :
585
- for _ in range (2 ):
586
- loss1 , pred1 = m1 (batch )
587
- loss2 , pred2 = m2 (batch )
588
- loss1 .backward ()
589
- loss2 .backward ()
590
- self .assertTrue (torch .equal (loss1 , loss2 ))
591
- self .assertTrue (torch .equal (pred1 , pred2 ))
592
- else :
593
- with torch .no_grad ():
594
- loss1 , pred1 = m1 (batch )
595
- loss2 , pred2 = m2 (batch )
596
- self .assertTrue (torch .equal (loss1 , loss2 ))
597
- self .assertTrue (torch .equal (pred1 , pred2 ))
598
- sd1 = m1 .state_dict ()
599
- for key , value in m2 .state_dict ().items ():
600
- v2 = sd1 [key ]
601
- if isinstance (value , ShardedTensor ):
602
- assert len (value .local_shards ()) == 1
603
- dst = value .local_shards ()[0 ].tensor
604
- else :
605
- dst = value
606
- if isinstance (v2 , ShardedTensor ):
607
- assert len (v2 .local_shards ()) == 1
608
- src = v2 .local_shards ()[0 ].tensor
609
- else :
610
- src = v2
611
- self .assertTrue (torch .equal (src , dst ))
634
+ self ._train_models (m1 , m2 , batch )
635
+ self ._eval_models (m1 , m2 , batch )
636
+ self ._compare_models (m1 , m2 )
612
637
613
638
# pyre-ignore[56]
614
639
@given (
@@ -661,34 +686,9 @@ def test_load_state_dict_prefix(
661
686
662
687
# validate the models are equivalent
663
688
if is_training :
664
- for _ in range (2 ):
665
- loss1 , pred1 = m1 (batch )
666
- loss2 , pred2 = m2 (batch )
667
- loss1 .backward ()
668
- loss2 .backward ()
669
- self .assertTrue (torch .equal (loss1 , loss2 ))
670
- self .assertTrue (torch .equal (pred1 , pred2 ))
671
- else :
672
- with torch .no_grad ():
673
- loss1 , pred1 = m1 (batch )
674
- loss2 , pred2 = m2 (batch )
675
- self .assertTrue (torch .equal (loss1 , loss2 ))
676
- self .assertTrue (torch .equal (pred1 , pred2 ))
677
-
678
- sd1 = m1 .state_dict ()
679
- for key , value in m2 .state_dict ().items ():
680
- v2 = sd1 [key ]
681
- if isinstance (value , ShardedTensor ):
682
- assert len (value .local_shards ()) == 1
683
- dst = value .local_shards ()[0 ].tensor
684
- else :
685
- dst = value
686
- if isinstance (v2 , ShardedTensor ):
687
- assert len (v2 .local_shards ()) == 1
688
- src = v2 .local_shards ()[0 ].tensor
689
- else :
690
- src = v2
691
- self .assertTrue (torch .equal (src , dst ))
689
+ self ._train_models (m1 , m2 , batch )
690
+ self ._eval_models (m1 , m2 , batch )
691
+ self ._compare_models (m1 , m2 )
692
692
693
693
# pyre-fixme[56]
694
694
@given (
@@ -807,19 +807,8 @@ def test_load_state_dict_cw_multiple_shards(
807
807
808
808
# validate the models are equivalent
809
809
if is_training :
810
- for _ in range (2 ):
811
- loss1 , pred1 = m1 (batch )
812
- loss2 , pred2 = m2 (batch )
813
- loss1 .backward ()
814
- loss2 .backward ()
815
- self .assertTrue (torch .equal (loss1 , loss2 ))
816
- self .assertTrue (torch .equal (pred1 , pred2 ))
817
- else :
818
- with torch .no_grad ():
819
- loss1 , pred1 = m1 (batch )
820
- loss2 , pred2 = m2 (batch )
821
- self .assertTrue (torch .equal (loss1 , loss2 ))
822
- self .assertTrue (torch .equal (pred1 , pred2 ))
810
+ self ._train_models (m1 , m2 , batch )
811
+ self ._eval_models (m1 , m2 , batch )
823
812
824
813
sd1 = m1 .state_dict ()
825
814
for key , value in m2 .state_dict ().items ():
@@ -876,3 +865,89 @@ def test_load_state_dict_cw_multiple_shards(
876
865
)
877
866
elif isinstance (dst_opt_state , torch .Tensor ):
878
867
self .assertIsInstance (src_opt_state , torch .Tensor )
868
+
869
+ @unittest .skipIf (
870
+ not torch .cuda .is_available (),
871
+ "Not enough GPUs, this test requires at least one GPU" ,
872
+ )
873
+ # pyre-ignore[56]
874
+ @given (
875
+ sharder_type = st .sampled_from (
876
+ [
877
+ SharderType .EMBEDDING_BAG_COLLECTION .value ,
878
+ ]
879
+ ),
880
+ sharding_type = st .sampled_from (
881
+ [
882
+ ShardingType .TABLE_WISE .value ,
883
+ ShardingType .COLUMN_WISE .value ,
884
+ ShardingType .ROW_WISE .value ,
885
+ ShardingType .TABLE_ROW_WISE .value ,
886
+ ShardingType .TABLE_COLUMN_WISE .value ,
887
+ ]
888
+ ),
889
+ kernel_type = st .sampled_from (
890
+ [
891
+ EmbeddingComputeKernel .FUSED_UVM_CACHING .value ,
892
+ EmbeddingComputeKernel .FUSED_UVM .value ,
893
+ ]
894
+ ),
895
+ is_training = st .booleans (),
896
+ stochastic_rounding = st .booleans (),
897
+ dtype = st .sampled_from ([DataType .FP32 , DataType .FP16 ]),
898
+ )
899
+ @settings (verbosity = Verbosity .verbose , max_examples = 4 , deadline = None )
900
+ def test_numerical_equivalence_between_kernel_types (
901
+ self ,
902
+ sharder_type : str ,
903
+ sharding_type : str ,
904
+ kernel_type : str ,
905
+ is_training : bool ,
906
+ stochastic_rounding : bool ,
907
+ dtype : DataType ,
908
+ ) -> None :
909
+ self ._set_table_weights_precision (dtype )
910
+ fused_params = {
911
+ "stochastic_rounding" : stochastic_rounding ,
912
+ "cache_precision" : dtype ,
913
+ }
914
+
915
+ fused_sharders = [
916
+ cast (
917
+ ModuleSharder [nn .Module ],
918
+ create_test_sharder (
919
+ sharder_type ,
920
+ sharding_type ,
921
+ EmbeddingComputeKernel .FUSED .value ,
922
+ fused_params = fused_params ,
923
+ ),
924
+ ),
925
+ ]
926
+ sharders = [
927
+ cast (
928
+ ModuleSharder [nn .Module ],
929
+ create_test_sharder (
930
+ sharder_type ,
931
+ sharding_type ,
932
+ kernel_type ,
933
+ fused_params = fused_params ,
934
+ ),
935
+ ),
936
+ ]
937
+ (fused_model , _ ), _ = self ._generate_dmps_and_batch (fused_sharders )
938
+ (model , _ ), batch = self ._generate_dmps_and_batch (sharders )
939
+
940
+ # load the baseline model's state_dict onto the new model
941
+ model .load_state_dict (
942
+ cast ("OrderedDict[str, torch.Tensor]" , fused_model .state_dict ())
943
+ )
944
+
945
+ if is_training :
946
+ for _ in range (4 ):
947
+ self ._train_models (fused_model , model , batch )
948
+ self ._eval_models (
949
+ fused_model , model , batch , is_deterministic = not stochastic_rounding
950
+ )
951
+ self ._compare_models (
952
+ fused_model , model , is_deterministic = not stochastic_rounding
953
+ )
0 commit comments