Crude cuda memory profiling
Waino committed Dec 4, 2023
1 parent e6ce860 commit 746f877
Showing 3 changed files with 16 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mammoth/distributed/communication.py
@@ -11,6 +11,14 @@
 from mammoth.utils.misc import set_random_seed
 
 
+def debug_cuda_mem(name, rank):
+    logger.warning('rank {} cuda allocated {}: {}'.format(
+        rank,
+        name,
+        torch.cuda.memory_allocated(device=rank))
+    )
+
+
 def multi_init(opts, global_rank):
     dist_init_method = 'tcp://{master_ip}:{master_port}'.format(master_ip=opts.master_ip, master_port=opts.master_port)
 
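A note on the probe itself: torch.cuda.memory_allocated(device) returns the number of bytes currently occupied by live tensors on that device, so the values logged by debug_cuda_mem are raw byte counts. A minimal sketch of a more readable variant, assuming torch and logger are already imported in this module as the added function implies; the MiB helper below is illustrative and not part of this commit:

def debug_cuda_mem_mib(name, rank):
    # Hypothetical variant of debug_cuda_mem that scales the byte count to MiB for readability.
    allocated_mib = torch.cuda.memory_allocated(device=rank) / (1024 ** 2)
    logger.warning('rank {} cuda allocated {}: {:.1f} MiB'.format(rank, name, allocated_mib))
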
3 changes: 3 additions & 0 deletions mammoth/train_single.py
@@ -12,6 +12,7 @@
 from mammoth.utils.parse import ArgumentParser
 
 from mammoth.distributed import broadcast_tensors
+from mammoth.distributed.communication import debug_cuda_mem
 from mammoth.inputters import DynamicDatasetIter
 from mammoth.transforms import get_transforms_cls
 
@@ -144,6 +145,7 @@ def main(
     else:
         # Initialize some data structures
         _ = task_queue_manager.get_distributed_groups()
+    debug_cuda_mem('after init distributed', torch.distributed.get_rank())
     enc, dec = model.count_parameters(log=logger.debug)
     logger.info("{} - total encoder parameters: {}".format(device_context.id, enc))
     logger.info("{} - total decoder parameters: {}".format(device_context.id, dec))
@@ -156,6 +158,7 @@
         task_queue_manager=task_queue_manager,
         checkpoint=checkpoint,
     )
+    debug_cuda_mem('after optimizer', torch.distributed.get_rank())
 
     # Build model saver
     model_saver = build_model_saver(model_opts, opts, model, vocabs_dict, optim, device_context)
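The two probes in main() bracket optimizer construction, so subtracting consecutive readings on a rank gives a rough estimate of the memory added by the optimizer state. A sketch of that differencing pattern; the helper and its module-level cache are hypothetical and not part of this commit:

_last_allocated = {}

def debug_cuda_mem_delta(name, rank):
    # Hypothetical helper: log the current allocation and the change since the previous probe on this rank.
    current = torch.cuda.memory_allocated(device=rank)
    delta = current - _last_allocated.get(rank, 0)
    _last_allocated[rank] = current
    logger.warning('rank {} cuda allocated {}: {} (delta {:+d})'.format(rank, name, current, delta))
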
5 changes: 5 additions & 0 deletions mammoth/trainer.py
@@ -18,6 +18,7 @@
 
 from itertools import islice
 from mammoth.utils.logging import logger
+from mammoth.distributed.communication import debug_cuda_mem
 
 
 def iter_on_device(iterator, device_context):
@@ -281,6 +282,7 @@ def train(
                 total_stats,
                 report_stats,
             )
+            debug_cuda_mem('after grad accum', torch.distributed.get_rank())
 
             # Note that all group ids are tuples, some with length 1
             for (layer_stack_index, encoder_id), (_, group) in self.my_encoder_groups.items():
@@ -334,6 +336,7 @@ def train(
             self._maybe_update_stats_from_parameters(report_stats, self.model.named_parameters())
 
             self.optim.step()
+            debug_cuda_mem('after optim step', torch.distributed.get_rank())
             self.optim.zero_grad()
             for p in self.model.parameters():
                 if hasattr(p, 'has_grad'):
@@ -501,6 +504,7 @@ def _gradient_accumulation_over_lang_pairs(
                 outputs, attns = self.model(
                     src, tgt, src_lengths, bptt=bptt, with_align=self.with_align, metadata=metadata
                 )
+                debug_cuda_mem('after forward', torch.distributed.get_rank())
                 bptt = True
 
                 # 3. Compute loss.
@@ -513,6 +517,7 @@
                     trunc_start=j,
                     trunc_size=trunc_size,
                 )
+                debug_cuda_mem('after loss', torch.distributed.get_rank())
                 # logger.info(loss)
 
                 try:
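memory_allocated only counts memory held by live tensors; the caching allocator's pool is invisible to it, and a short-lived peak inside the forward/backward pass may already have passed by the time a probe runs. If fuller numbers were wanted, torch.cuda.max_memory_allocated and torch.cuda.memory_reserved report the peak allocation and the allocator's reserved pool; the extended probe below is a hedged sketch, not part of this commit:

def debug_cuda_mem_full(name, rank):
    # Hypothetical extension: also report peak allocation and the caching allocator's reserved pool.
    logger.warning('rank {} cuda mem {}: allocated {} / peak {} / reserved {}'.format(
        rank,
        name,
        torch.cuda.memory_allocated(device=rank),
        torch.cuda.max_memory_allocated(device=rank),
        torch.cuda.memory_reserved(device=rank),
    ))
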
