Skip to content

Commit e719551

Browse files
henrylhtsang and facebook-github-bot
authored and committed
Add numerical equivalence tests for model_parallel (#1721)
Summary: Pull Request resolved: #1721 Adding numerical equivalence tests for different kernel types. Note that the combination of stochastic rounding = True and is_training = True is always a bit tricky. However, according to the FBGEMM team, iteration-wise equivalence is not guaranteed in that case, but numerical equivalence in the long term is. Reviewed By: sarckk Differential Revision: D53878573 fbshipit-source-id: 5b0eb4f1977a448f27db80e9bd19b02cb58b15bb
1 parent d740e1f commit e719551

4 files changed

+372
-214
lines changed

torchrec/distributed/test_utils/test_model_parallel_base.py

+177-102
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from torchrec.modules.embedding_configs import (
5454
BaseEmbeddingConfig,
55+
DataType,
5556
EmbeddingBagConfig,
5657
PoolingType,
5758
)
@@ -240,7 +241,7 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
240241
self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
241242

242243

243-
class ModelParallelStateDictBase(unittest.TestCase):
244+
class ModelParallelSingleRankBase(unittest.TestCase):
244245
def setUp(self, backend: str = "nccl") -> None:
245246
os.environ["RANK"] = "0"
246247
os.environ["WORLD_SIZE"] = "1"
@@ -261,9 +262,79 @@ def setUp(self, backend: str = "nccl") -> None:
261262

262263
dist.init_process_group(backend=backend)
263264

265+
def tearDown(self) -> None:
    """Undo setUp(): destroy the single-rank process group and clean env."""
    dist.destroy_process_group()
    # setUp() is expected to have set this; deleting it keeps later tests
    # (or other suites in the same process) from inheriting a stale value.
    del os.environ["NCCL_SOCKET_IFNAME"]
    super().tearDown()
269+
270+
def _train_models(
271+
self,
272+
m1: DistributedModelParallel,
273+
m2: DistributedModelParallel,
274+
batch: ModelInput,
275+
) -> None:
276+
loss1, pred1 = m1(batch)
277+
loss2, pred2 = m2(batch)
278+
loss1.backward()
279+
loss2.backward()
280+
281+
def _eval_models(
282+
self,
283+
m1: DistributedModelParallel,
284+
m2: DistributedModelParallel,
285+
batch: ModelInput,
286+
is_deterministic: bool = True,
287+
) -> None:
288+
with torch.no_grad():
289+
loss1, pred1 = m1(batch)
290+
loss2, pred2 = m2(batch)
291+
292+
if is_deterministic:
293+
self.assertTrue(torch.equal(loss1, loss2))
294+
self.assertTrue(torch.equal(pred1, pred2))
295+
else:
296+
rtol, atol = _get_default_rtol_and_atol(loss1, loss2)
297+
torch.testing.assert_close(loss1, loss2, rtol=rtol, atol=atol)
298+
rtol, atol = _get_default_rtol_and_atol(pred1, pred2)
299+
torch.testing.assert_close(pred1, pred2, rtol=rtol, atol=atol)
300+
301+
def _compare_models(
302+
self,
303+
m1: DistributedModelParallel,
304+
m2: DistributedModelParallel,
305+
is_deterministic: bool = True,
306+
) -> None:
307+
sd1 = m1.state_dict()
308+
for key, value in m2.state_dict().items():
309+
v2 = sd1[key]
310+
if isinstance(value, ShardedTensor):
311+
assert isinstance(v2, ShardedTensor)
312+
self.assertEqual(len(value.local_shards()), len(v2.local_shards()))
313+
for dst, src in zip(value.local_shards(), v2.local_shards()):
314+
if is_deterministic:
315+
self.assertTrue(torch.equal(src.tensor, dst.tensor))
316+
else:
317+
rtol, atol = _get_default_rtol_and_atol(src.tensor, dst.tensor)
318+
torch.testing.assert_close(
319+
src.tensor, dst.tensor, rtol=rtol, atol=atol
320+
)
321+
else:
322+
dst = value
323+
src = v2
324+
if is_deterministic:
325+
self.assertTrue(torch.equal(src, dst))
326+
else:
327+
rtol, atol = _get_default_rtol_and_atol(src, dst)
328+
torch.testing.assert_close(src, dst, rtol=rtol, atol=atol)
329+
330+
331+
class ModelParallelStateDictBase(ModelParallelSingleRankBase):
332+
def setUp(self, backend: str = "nccl") -> None:
333+
super().setUp(backend=backend)
334+
264335
num_features = 4
265336
num_weighted_features = 2
266-
self.batch_size = 3
337+
self.batch_size = 20
267338
self.num_float_features = 10
268339

269340
self.tables = [
@@ -285,11 +356,6 @@ def setUp(self, backend: str = "nccl") -> None:
285356
for i in range(num_weighted_features)
286357
]
287358

288-
def tearDown(self) -> None:
289-
dist.destroy_process_group()
290-
del os.environ["NCCL_SOCKET_IFNAME"]
291-
super().tearDown()
292-
293359
def _generate_dmps_and_batch(
294360
self,
295361
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
@@ -352,6 +418,13 @@ def _generate_dmps_and_batch(
352418
dmps.append(dmp)
353419
return (dmps, batch)
354420

421+
def _set_table_weights_precision(self, dtype: DataType) -> None:
422+
for table in self.tables:
423+
table.data_type = dtype
424+
425+
for weighted_table in self.weighted_tables:
426+
weighted_table.data_type = dtype
427+
355428
def test_parameter_init(self) -> None:
356429
class MyModel(nn.Module):
357430
def __init__(self, device: str, val: float) -> None:
@@ -513,33 +586,9 @@ def test_load_state_dict(
513586

514587
# validate the models are equivalent
515588
if is_training:
516-
for _ in range(2):
517-
loss1, pred1 = m1(batch)
518-
loss2, pred2 = m2(batch)
519-
loss1.backward()
520-
loss2.backward()
521-
self.assertTrue(torch.equal(loss1, loss2))
522-
self.assertTrue(torch.equal(pred1, pred2))
523-
else:
524-
with torch.no_grad():
525-
loss1, pred1 = m1(batch)
526-
loss2, pred2 = m2(batch)
527-
self.assertTrue(torch.equal(loss1, loss2))
528-
self.assertTrue(torch.equal(pred1, pred2))
529-
sd1 = m1.state_dict()
530-
for key, value in m2.state_dict().items():
531-
v2 = sd1[key]
532-
if isinstance(value, ShardedTensor):
533-
assert len(value.local_shards()) == 1
534-
dst = value.local_shards()[0].tensor
535-
else:
536-
dst = value
537-
if isinstance(v2, ShardedTensor):
538-
assert len(v2.local_shards()) == 1
539-
src = v2.local_shards()[0].tensor
540-
else:
541-
src = v2
542-
self.assertTrue(torch.equal(src, dst))
589+
self._train_models(m1, m2, batch)
590+
self._eval_models(m1, m2, batch)
591+
self._compare_models(m1, m2)
543592

544593
# pyre-ignore[56]
545594
@given(
@@ -582,33 +631,9 @@ def test_load_state_dict_dp(
582631

583632
# validate the models are equivalent
584633
if is_training:
585-
for _ in range(2):
586-
loss1, pred1 = m1(batch)
587-
loss2, pred2 = m2(batch)
588-
loss1.backward()
589-
loss2.backward()
590-
self.assertTrue(torch.equal(loss1, loss2))
591-
self.assertTrue(torch.equal(pred1, pred2))
592-
else:
593-
with torch.no_grad():
594-
loss1, pred1 = m1(batch)
595-
loss2, pred2 = m2(batch)
596-
self.assertTrue(torch.equal(loss1, loss2))
597-
self.assertTrue(torch.equal(pred1, pred2))
598-
sd1 = m1.state_dict()
599-
for key, value in m2.state_dict().items():
600-
v2 = sd1[key]
601-
if isinstance(value, ShardedTensor):
602-
assert len(value.local_shards()) == 1
603-
dst = value.local_shards()[0].tensor
604-
else:
605-
dst = value
606-
if isinstance(v2, ShardedTensor):
607-
assert len(v2.local_shards()) == 1
608-
src = v2.local_shards()[0].tensor
609-
else:
610-
src = v2
611-
self.assertTrue(torch.equal(src, dst))
634+
self._train_models(m1, m2, batch)
635+
self._eval_models(m1, m2, batch)
636+
self._compare_models(m1, m2)
612637

613638
# pyre-ignore[56]
614639
@given(
@@ -661,34 +686,9 @@ def test_load_state_dict_prefix(
661686

662687
# validate the models are equivalent
663688
if is_training:
664-
for _ in range(2):
665-
loss1, pred1 = m1(batch)
666-
loss2, pred2 = m2(batch)
667-
loss1.backward()
668-
loss2.backward()
669-
self.assertTrue(torch.equal(loss1, loss2))
670-
self.assertTrue(torch.equal(pred1, pred2))
671-
else:
672-
with torch.no_grad():
673-
loss1, pred1 = m1(batch)
674-
loss2, pred2 = m2(batch)
675-
self.assertTrue(torch.equal(loss1, loss2))
676-
self.assertTrue(torch.equal(pred1, pred2))
677-
678-
sd1 = m1.state_dict()
679-
for key, value in m2.state_dict().items():
680-
v2 = sd1[key]
681-
if isinstance(value, ShardedTensor):
682-
assert len(value.local_shards()) == 1
683-
dst = value.local_shards()[0].tensor
684-
else:
685-
dst = value
686-
if isinstance(v2, ShardedTensor):
687-
assert len(v2.local_shards()) == 1
688-
src = v2.local_shards()[0].tensor
689-
else:
690-
src = v2
691-
self.assertTrue(torch.equal(src, dst))
689+
self._train_models(m1, m2, batch)
690+
self._eval_models(m1, m2, batch)
691+
self._compare_models(m1, m2)
692692

693693
# pyre-fixme[56]
694694
@given(
@@ -807,19 +807,8 @@ def test_load_state_dict_cw_multiple_shards(
807807

808808
# validate the models are equivalent
809809
if is_training:
810-
for _ in range(2):
811-
loss1, pred1 = m1(batch)
812-
loss2, pred2 = m2(batch)
813-
loss1.backward()
814-
loss2.backward()
815-
self.assertTrue(torch.equal(loss1, loss2))
816-
self.assertTrue(torch.equal(pred1, pred2))
817-
else:
818-
with torch.no_grad():
819-
loss1, pred1 = m1(batch)
820-
loss2, pred2 = m2(batch)
821-
self.assertTrue(torch.equal(loss1, loss2))
822-
self.assertTrue(torch.equal(pred1, pred2))
810+
self._train_models(m1, m2, batch)
811+
self._eval_models(m1, m2, batch)
823812

824813
sd1 = m1.state_dict()
825814
for key, value in m2.state_dict().items():
@@ -876,3 +865,89 @@ def test_load_state_dict_cw_multiple_shards(
876865
)
877866
elif isinstance(dst_opt_state, torch.Tensor):
878867
self.assertIsInstance(src_opt_state, torch.Tensor)
868+
869+
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
# pyre-ignore[56]
@given(
    sharder_type=st.sampled_from(
        [
            SharderType.EMBEDDING_BAG_COLLECTION.value,
        ]
    ),
    sharding_type=st.sampled_from(
        [
            ShardingType.TABLE_WISE.value,
            ShardingType.COLUMN_WISE.value,
            ShardingType.ROW_WISE.value,
            ShardingType.TABLE_ROW_WISE.value,
            ShardingType.TABLE_COLUMN_WISE.value,
        ]
    ),
    kernel_type=st.sampled_from(
        [
            EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
            EmbeddingComputeKernel.FUSED_UVM.value,
        ]
    ),
    is_training=st.booleans(),
    stochastic_rounding=st.booleans(),
    dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
)
@settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
def test_numerical_equivalence_between_kernel_types(
    self,
    sharder_type: str,
    sharding_type: str,
    kernel_type: str,
    is_training: bool,
    stochastic_rounding: bool,
    dtype: DataType,
) -> None:
    """Verify a UVM/UVM-caching kernel matches the plain FUSED kernel.

    Builds two otherwise-identical single-rank models — one sharded with
    EmbeddingComputeKernel.FUSED (baseline) and one with the sampled
    kernel_type — seeds the second from the first's state_dict, then
    checks that training steps, eval outputs, and final weights agree.

    NOTE(review): with stochastic_rounding=True the comparison is done
    with tolerances rather than exact equality — per the commit summary,
    iteration-wise bit-equivalence is not guaranteed in that mode.
    """
    self._set_table_weights_precision(dtype)
    fused_params = {
        "stochastic_rounding": stochastic_rounding,
        "cache_precision": dtype,
    }

    # Baseline sharders: same sharder/sharding type but the FUSED kernel.
    fused_sharders = [
        cast(
            ModuleSharder[nn.Module],
            create_test_sharder(
                sharder_type,
                sharding_type,
                EmbeddingComputeKernel.FUSED.value,
                fused_params=fused_params,
            ),
        ),
    ]
    # Sharders under test: the sampled kernel type.
    sharders = [
        cast(
            ModuleSharder[nn.Module],
            create_test_sharder(
                sharder_type,
                sharding_type,
                kernel_type,
                fused_params=fused_params,
            ),
        ),
    ]
    (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders)
    (model, _), batch = self._generate_dmps_and_batch(sharders)

    # load the baseline model's state_dict onto the new model
    model.load_state_dict(
        cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict())
    )

    if is_training:
        # Several steps so optimizer/weight updates get exercised, not
        # just the initial forward pass.
        for _ in range(4):
            self._train_models(fused_model, model, batch)
    # Exact comparison only when stochastic rounding is off; otherwise
    # fall back to tolerance-based checks.
    self._eval_models(
        fused_model, model, batch, is_deterministic=not stochastic_rounding
    )
    self._compare_models(
        fused_model, model, is_deterministic=not stochastic_rounding
    )

0 commit comments

Comments
 (0)