diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py index c1ce3c71..57b71c71 100644 --- a/mammoth/distributed/communication.py +++ b/mammoth/distributed/communication.py @@ -3,6 +3,7 @@ import os import pickle import signal +import time import torch import torch.distributed @@ -213,11 +214,17 @@ def batch_producer(generator_to_serve, queue, semaphore, opts, device_id): logger.info(generator_to_serve) for batch, metadata, communication_batch_id in generator_to_serve: + start = time.time() semaphore.acquire() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;producer_semaphore;{duration}') # Move batch to correspond device_id when consumer iterate # hack to dodge unpicklable `dict_keys` # batch.fields = list(batch.fields) + start = time.time() queue.put((batch, metadata, communication_batch_id)) + duration = time.time() - start + logger.warning(f'QUEUE_PERF;producer_put;{duration}') def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint): diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 72abab8c..fbb48fab 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -145,8 +145,14 @@ def main( def _train_iter(): while True: + start = time.time() batch, metadata, communication_batch_id = batch_queue.get() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;consumer_get;{duration}') + start = time.time() semaphore.release() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;consumer_semaphore;{duration}') # TODO: confirm that batch-providing corpus has already been to'd to the correct place yield batch, metadata, communication_batch_id