diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 03aa3ea96..53b4b3b35 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -46,6 +46,7 @@ from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
     create_module_and_freeze,
 )
+from torchrec.distributed.train_pipeline import TorchCompileConfig
 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
 )
@@ -134,6 +135,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
 class TrainPipelineBaseTest(unittest.TestCase):
     def setUp(self) -> None:
         self.device = torch.device("cuda:0")
+        self.optimizer_compile_config = TorchCompileConfig()
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -156,7 +158,43 @@ def test_equal_to_non_pipelined(self) -> None:
             for b in range(5)
         ]
         dataloader = iter(data)
-        pipeline = TrainPipelineBase(model_gpu, optimizer_gpu, self.device)
+        pipeline = TrainPipelineBase(
+            model_gpu, optimizer=optimizer_gpu, device=self.device
+        )
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss.backward()
+            optimizer_cpu.step()
+
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            # Results will be close but not exactly equal, since one model runs on the CPU
+            # and the other on the GPU. If both ran on the GPU, the results would match exactly.
+            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
+
+    def test_equal_to_non_pipelined_compiled(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+        dataloader = iter(data)
+        pipeline = TrainPipelineBase(
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )
 
         for batch in data[:-1]:
             optimizer_cpu.zero_grad()
@@ -175,6 +213,7 @@ def test_equal_to_non_pipelined(self) -> None:
 class TrainPipelinePT2Test(unittest.TestCase):
     def setUp(self) -> None:
         self.device = torch.device("cuda:0")
+        self.optimizer_compile_config = TorchCompileConfig()
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -234,7 +273,41 @@ def test_equal_to_non_pipelined(self) -> None:
             for b in range(5)
         ]
         dataloader = iter(data)
-        pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device)
+        pipeline = TrainPipelinePT2(
+            model_gpu, optimizer=optimizer_gpu, device=self.device
+        )
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss.backward()
+            optimizer_cpu.step()
+
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
+
+    def test_equal_to_non_pipelined_compiled(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+        dataloader = iter(data)
+        pipeline = TrainPipelinePT2(
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )
 
         for batch in data[:-1]:
             optimizer_cpu.zero_grad()
@@ -271,7 +344,10 @@ def pre_compile_fn(model: nn.Module) -> None:
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
-            model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            pre_compile_fn=pre_compile_fn,
         )
         self.assertEqual(model_gpu._dummy_setting, "dummy")
         for _ in range(len(data)):
@@ -315,7 +391,10 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         ]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
-            model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            input_transformer=kjt_for_pt2_tracing,
         )
 
         for batch in data[:-1]:
@@ -545,6 +624,83 @@ def test_equal_to_non_pipelined(
         self.assertRaises(StopIteration, pipeline.progress, dataloader)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+        execute_all_batches=st.booleans(),
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        execute_all_batches: bool,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+        """
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params = {}
+        fused_params_pipelined = {}
+
+        model = self._setup_model()
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = self.pipeline_class(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=execute_all_batches,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )
+        if not execute_all_batches:
+            data = data[:-2]
+
+        for batch in data:
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred, pred_pipeline)
+
+        self.assertRaises(StopIteration, pipeline.progress, dataloader)
+
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
     )
@@ -1669,6 +1825,149 @@ def test_equal_to_non_pipelined(
         pred_pipeline = pipeline.progress(dataloader)
         self.assertRaises(StopIteration, pipeline.progress, dataloader)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=8, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        start_batch=st.sampled_from([0, 6]),
+        stash_gradients=st.booleans(),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.ROW_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+        zch=st.booleans(),
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        start_batch: int,
+        stash_gradients: bool,
+        sharding_type: str,
+        kernel_type: str,
+        zch: bool,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+        """
+        # ZCH only supports row-wise currently
+        assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value))
+        torch.autograd.set_detect_anomaly(True)
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params = {
+            "stochastic_rounding": False,
+        }
+        fused_params_pipelined = {
+            **fused_params,
+        }
+
+        model = self._setup_model(zch=zch)
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = TrainPipelineSemiSync(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=True,
+            start_batch=start_batch,
+            stash_gradients=stash_gradients,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )
+
+        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+        # `sparse_forward`.
+        prior_sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(
+            data[0].to(self.device)
+        )
+        prior_batch = data[0].to(self.device)
+        prior_stashed_grads = None
+        batch_index = 0
+        sparse_out = None
+        for batch in data[1:]:
+            batch_index += 1
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `dense_forward`.
+            loss, pred = sharded_model._dmp_wrapped_module.dense_forward(
+                prior_batch, prior_sparse_out
+            )
+            if batch_index - 1 >= start_batch:
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `sparse_forward`.
+                sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)
+
+            loss.backward()
+
+            stashed_grads = None
+            if batch_index - 1 >= start_batch and stash_gradients:
+                stashed_grads = []
+                for param in optim.param_groups[0]["params"]:
+                    stashed_grads.append(
+                        param.grad.clone() if param.grad is not None else None
+                    )
+                    param.grad = None
+
+            if prior_stashed_grads is not None:
+                for param, stashed_grad in zip(
+                    optim.param_groups[0]["params"], prior_stashed_grads
+                ):
+                    param.grad = stashed_grad
+            optim.step()
+            optim.zero_grad()
+
+            if batch_index - 1 < start_batch:
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `sparse_forward`.
+                sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)
+
+            prior_stashed_grads = stashed_grads
+            prior_batch = batch
+            prior_sparse_out = sparse_out
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+
+            if batch_index >= start_batch:
+                self.assertTrue(
+                    pipeline.is_semi_sync(), msg="pipeline is not semi_sync"
+                )
+            else:
+                self.assertFalse(pipeline.is_semi_sync(), msg="pipeline is semi_sync")
+            self.assertTrue(
+                torch.equal(pred, pred_pipeline),
+                msg=f"batch {batch_index} doesn't match",
+            )
+
+        # one more batch
+        pred_pipeline = pipeline.progress(dataloader)
+        self.assertRaises(StopIteration, pipeline.progress, dataloader)
+
 
 class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
     )
@@ -1783,6 +2082,119 @@ def test_equal_to_non_pipelined(
         else:
             torch.testing.assert_close(pred, pred_pipeline)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        execute_all_batches=st.booleans(),
+        weight_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+                0.6,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        execute_all_batches: bool,
+        weight_precision: DataType,
+        cache_precision: DataType,
+        load_factor: float,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+        """
+        mixed_precision: bool = weight_precision != cache_precision
+        self._set_table_weights_precision(weight_precision)
+        data = self._generate_data(
+            num_batches=12,
+            batch_size=32,
+        )
+        dataloader = iter(data)
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        model = self._setup_model()
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        pipeline = PrefetchTrainPipelineSparseDist(
+            model=sharded_model_pipelined,
+            optimizer=optim_pipelined,
+            device=self.device,
+            execute_all_batches=execute_all_batches,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )
+
+        if not execute_all_batches:
+            data = data[:-3]
+
+        for batch in data:
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+
+            if not mixed_precision:
+                # Rounding error is expected when using different precisions for weights and cache
+                self.assertTrue(torch.equal(pred, pred_pipeline))
+            else:
+                torch.testing.assert_close(pred, pred_pipeline)
+
 
 class DataLoadingThreadTest(unittest.TestCase):
     def test_fetch_data(self) -> None:
@@ -2365,3 +2777,14 @@ def test_equal_to_non_pipelined(
         execute_all_batches: bool,
     ) -> None:
         super().test_equal_to_non_pipelined()
+
+    @unittest.skip(
+        "TrainPipelineSparseDistTest.test_equal_to_non_pipelined_compiled is called from multiple executors, which fails the hypothesis HealthCheck, so it is skipped here"
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        execute_all_batches: bool,
+    ) -> None:
+        super().test_equal_to_non_pipelined_compiled()
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 56e6ac636..a68f78295 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -24,6 +24,7 @@
     TestEBCSharderMCH,
     TestSparseNN,
 )
+from torchrec.distributed.train_pipeline import TorchCompileConfig
 from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
 from torchrec.distributed.types import ModuleSharder, ShardingEnv
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
@@ -62,6 +63,7 @@ def setUp(self) -> None:
         self.device = torch.device("cuda:0")
         self.pipeline_class = TrainPipelineSparseDist
+        self.optimizer_compile_config = TorchCompileConfig()
 
     def tearDown(self) -> None:
         super().tearDown()
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index fcd7efc24..1311832a2 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -121,6 +121,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -140,6 +141,16 @@ def __init__(
         self._cur_batch: Optional[In] = None
         self._connected = False
 
+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def _connect(self, dataloader_iter: Iterator[In]) -> None:
         cur_batch = next(dataloader_iter)
         self._cur_batch = cur_batch
@@ -193,7 +204,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # Update
         if self._model.training:
             with record_function("## optimizer ##"):
-                self._optimizer.step()
+                self._optimizer_step()
 
         return output
 
@@ -221,6 +232,7 @@ def __init__(
         pre_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
         post_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
         input_transformer: Optional[Callable[[In], In]] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -237,6 +249,16 @@ def __init__(
         self._iter = 0
         self._cur_batch: Optional[In] = None
 
+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if self._iter == 0:
             # Turn on sync collectives for PT2 pipeline.
@@ -292,7 +314,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             torch.sum(losses).backward()
 
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()
 
         return output
 
@@ -337,6 +359,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -399,6 +422,16 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context: TrainPipelineContext = context_type(version=0)
 
+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def detach(self) -> torch.nn.Module:
         """
         Detaches the model from sparse data dist (SDD) pipeline.
@@ -530,7 +563,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()
 
         self.dequeue_batch()
         return output
@@ -757,6 +790,7 @@ def __init__(
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
         strict: bool = False,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -767,6 +801,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -836,7 +871,7 @@ def _mlp_optimizer_step(self, current_batch: int) -> None:
         # special case: not all optimizers support optim.step() on null gradidents
         if current_batch == self._start_batch and self._stash_gradients:
             return
-        self._optimizer.step()
+        self._optimizer_step()
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.fill_pipeline(dataloader_iter)
@@ -1056,6 +1091,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
    ) -> None:
         super().__init__(
             model=model,
@@ -1066,6 +1102,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1133,7 +1170,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()
 
         self._start_sparse_data_dist(self._batch_ip2)
@@ -1568,16 +1605,18 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         super().__init__(
-            model,
-            optimizer,
-            device,
-            execute_all_batches,
-            apply_jit,
-            context_type,
-            pipeline_postproc,
-            custom_model_fwd,
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=context_type,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)
@@ -1665,7 +1704,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
 
         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()
 
         self.dequeue_batch()
         return output
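
Usage note (not part of the patch): the sketch below shows how the new `optimizer_compile_config` argument introduced in this diff is intended to be used from the caller's side, and roughly what the pipelines now do with it internally. Only `TorchCompileConfig`, the `optimizer_compile_config` keyword, and the `fullgraph` / `dynamic` / `backend` fields forwarded to `torch.compile` come from the patch itself; the toy module, optimizer choice, and variable names are illustrative assumptions.

```python
from typing import Tuple

import torch
from torch import nn, optim

# TorchCompileConfig is imported from this package in the new tests; TrainPipelineBase
# is assumed to be re-exported from the same package.
from torchrec.distributed.train_pipeline import TorchCompileConfig, TrainPipelineBase


class ToyModule(nn.Module):
    """Hypothetical stand-in: TrainPipelineBase expects forward() to return (loss, pred)."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(10, 1)

    def forward(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        pred = self.model(batch)
        return pred.sum(), pred


device = torch.device("cuda:0")
model = ToyModule().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Default-constructed config, as the new tests do. Per the patch, the config carries
# fullgraph / dynamic / backend options that the pipeline passes through to torch.compile.
compile_config = TorchCompileConfig()

pipeline = TrainPipelineBase(
    model=model,
    optimizer=optimizer,
    device=device,
    optimizer_compile_config=compile_config,
)

# Internally (see the __init__ changes above), each pipeline wraps the optimizer step
# roughly like this when a config is supplied, and otherwise falls back to the plain
# bound method optimizer.step:
optimizer_step = torch.compile(
    lambda: optimizer.step(),
    fullgraph=compile_config.fullgraph,
    dynamic=compile_config.dynamic,
    backend=compile_config.backend,
)
```

Because the compiled callable closes over the optimizer rather than being traced per call site, `progress()` keeps its existing `record_function("## optimizer ##")` region and simply invokes `self._optimizer_step()`, which is why the tests only need to assert that compiled and non-compiled runs produce matching predictions.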