From acff38fe96ec82a90ab9fb9d1e337a916034b4b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 13 May 2024 15:05:15 +0300 Subject: [PATCH] Crude profiling of time waiting for the multiprocessing queue A semicolon separated csv format can be extracted from logs using the marker string "QUEUE_PERF". Both the semaphore and the put/get are measured, for both producer and consumer. --- mammoth/distributed/communication.py | 7 +++++++ mammoth/train_single.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py index c1ce3c71..57b71c71 100644 --- a/mammoth/distributed/communication.py +++ b/mammoth/distributed/communication.py @@ -3,6 +3,7 @@ import os import pickle import signal +import time import torch import torch.distributed @@ -213,11 +214,17 @@ def batch_producer(generator_to_serve, queue, semaphore, opts, device_id): logger.info(generator_to_serve) for batch, metadata, communication_batch_id in generator_to_serve: + start = time.time() semaphore.acquire() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;producer_semaphore;{duration}') # Move batch to correspond device_id when consumer iterate # hack to dodge unpicklable `dict_keys` # batch.fields = list(batch.fields) + start = time.time() queue.put((batch, metadata, communication_batch_id)) + duration = time.time() - start + logger.warning(f'QUEUE_PERF;producer_put;{duration}') def consumer(process_fn, opts, device_context, error_queue, batch_queue, semaphore, task_queue_manager, checkpoint): diff --git a/mammoth/train_single.py b/mammoth/train_single.py index 72abab8c..fbb48fab 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -145,8 +145,14 @@ def main( def _train_iter(): while True: + start = time.time() batch, metadata, communication_batch_id = batch_queue.get() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;consumer_get;{duration}') + start = time.time() semaphore.release() + duration = time.time() - start + logger.warning(f'QUEUE_PERF;consumer_semaphore;{duration}') # TODO: confirm that batch-providing corpus has already been to'd to the correct place yield batch, metadata, communication_batch_id