From 746f877ebcae8f0fdf86d5a7991b305cd81e7e3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 4 Dec 2023 09:55:14 +0200
Subject: [PATCH] Crude cuda memory profiling

---
 mammoth/distributed/communication.py | 8 ++++++++
 mammoth/train_single.py              | 3 +++
 mammoth/trainer.py                   | 5 +++++
 3 files changed, 16 insertions(+)

diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py
index 687da4b9..5072f277 100644
--- a/mammoth/distributed/communication.py
+++ b/mammoth/distributed/communication.py
@@ -11,6 +11,14 @@
 from mammoth.utils.misc import set_random_seed
 
 
+def debug_cuda_mem(name, rank):
+    logger.warning('rank {} cuda allocated {}: {}'.format(
+        rank,
+        name,
+        torch.cuda.memory_allocated(device=rank))
+    )
+
+
 def multi_init(opts, global_rank):
     dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opts.master_ip, master_port=opts.master_port)
 
diff --git a/mammoth/train_single.py b/mammoth/train_single.py
index f036f0d9..30c1dc33 100644
--- a/mammoth/train_single.py
+++ b/mammoth/train_single.py
@@ -12,6 +12,7 @@
 
 from mammoth.utils.parse import ArgumentParser
 from mammoth.distributed import broadcast_tensors
+from mammoth.distributed.communication import debug_cuda_mem
 from mammoth.inputters import DynamicDatasetIter
 from mammoth.transforms import get_transforms_cls
 
@@ -144,6 +145,7 @@ def main(
     else:
         # Initialize some data structures
         _ = task_queue_manager.get_distributed_groups()
+        debug_cuda_mem('after init distributed', torch.distributed.get_rank())
     enc, dec = model.count_parameters(log=logger.debug)
     logger.info("{} - total encoder parameters: {}".format(device_context.id, enc))
     logger.info("{} - total decoder parameters: {}".format(device_context.id, dec))
@@ -156,6 +158,7 @@ def main(
         task_queue_manager=task_queue_manager,
         checkpoint=checkpoint,
     )
+    debug_cuda_mem('after optimizer', torch.distributed.get_rank())
 
     # Build model saver
     model_saver = build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context)
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index d5edf246..cdaa5898 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -18,6 +18,7 @@
 from itertools import islice
 
 from mammoth.utils.logging import logger
+from mammoth.distributed.communication import debug_cuda_mem
 
 
 def iter_on_device(iterator, device_context):
@@ -281,6 +282,7 @@ def train(
                 total_stats,
                 report_stats,
             )
+            debug_cuda_mem('after grad accum', torch.distributed.get_rank())
 
             # Note that all group ids are tuples, some with length 1
             for (layer_stack_index, encoder_id), (_, group) in self.my_encoder_groups.items():
@@ -334,6 +336,7 @@ def train(
             self._maybe_update_stats_from_parameters(report_stats, self.model.named_parameters())
 
             self.optim.step()
+            debug_cuda_mem('after optim step', torch.distributed.get_rank())
             self.optim.zero_grad()
             for p in self.model.parameters():
                 if hasattr(p, 'has_grad'):
@@ -501,6 +504,7 @@ def _gradient_accumulation_over_lang_pairs(
                 outputs, attns = self.model(
                     src, tgt, src_lengths, bptt=bptt, with_align=self.with_align, metadata=metadata
                 )
+                debug_cuda_mem('after forward', torch.distributed.get_rank())
                 bptt = True
 
                 # 3. Compute loss.
@@ -513,6 +517,7 @@ def _gradient_accumulation_over_lang_pairs(
                     trunc_start=j,
                     trunc_size=trunc_size,
                 )
+                debug_cuda_mem('after loss', torch.distributed.get_rank())
 
                 # logger.info(loss)
                 try:
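
Note (not part of the patch above): for comparison, a minimal standalone sketch of a slightly richer probe. Like the patch's debug_cuda_mem, it assumes the distributed rank can be used directly as the CUDA device index. The name debug_cuda_mem_verbose and the MiB formatting are illustrative, not existing mammoth code; it only uses torch.cuda.memory_allocated, torch.cuda.memory_reserved and torch.cuda.max_memory_allocated, which exist in the public torch API.

# Standalone sketch, not part of the patch: a more detailed CUDA memory probe.
# Assumes rank maps directly to a CUDA device index, as debug_cuda_mem does.
import logging

import torch

logger = logging.getLogger(__name__)


def debug_cuda_mem_verbose(name, rank):
    """Log allocated, reserved and peak allocated CUDA memory for this rank, in MiB."""
    if not torch.cuda.is_available():
        logger.warning('rank %s: CUDA not available at %s', rank, name)
        return
    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated(device=rank) / mib
    reserved = torch.cuda.memory_reserved(device=rank) / mib
    peak = torch.cuda.max_memory_allocated(device=rank) / mib
    logger.warning(
        'rank %d cuda mem %s: allocated %.1f MiB, reserved %.1f MiB, peak %.1f MiB',
        rank, name, allocated, reserved, peak,
    )

Such a helper would be called at the same points the patch instruments, e.g. debug_cuda_mem_verbose('after forward', torch.distributed.get_rank()).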