From 80b508b6e5470c878bf51168bb8e2f4020b1bed6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 26 Feb 2024 19:56:32 +0200
Subject: [PATCH] Initialize TaskDistributionStrategy in global_to_local

This guarantees that the producer and consumer don't share state
---
 mammoth/distributed/tasks.py | 16 ++++++++--------
 mammoth/trainer.py           |  1 -
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/mammoth/distributed/tasks.py b/mammoth/distributed/tasks.py
index eab23802..76cef7ad 100644
--- a/mammoth/distributed/tasks.py
+++ b/mammoth/distributed/tasks.py
@@ -172,8 +172,8 @@ def __init__(
         tasks: List[TaskSpecs],
         accum_count: int,
         world_context: WorldContext,
+        task_distribution_strategy_cls: type,
         distributed_components=None,
-        task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
         uses_adapters: bool = False,
     ):
         """
@@ -190,7 +190,7 @@ def __init__(
         self.tasks = tasks
         # TODO: no support for variable accumulation across training
         self.accum_count = accum_count[0] if isinstance(accum_count, list) else accum_count
-        self.task_distribution_strategy = task_distribution_strategy
+        self.task_distribution_strategy_cls = task_distribution_strategy_cls
         self.world_context = world_context
         self.uses_adapters = uses_adapters
 
@@ -244,9 +244,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
             if not len(opts.dec_layers) == 1:
                 raise Exception('With more than one decoder stack, you must explictly define dec_sharing_group')
 
-        task_distribution_strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy](
-            seed=opts.seed,
-        )
+        task_distribution_strategy_cls = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy]
         tasks = []
         uses_adapters = False
         for (
@@ -290,7 +288,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
             tasks,
             world_context=world_context,
             accum_count=opts.accum_count,
-            task_distribution_strategy=task_distribution_strategy,
+            task_distribution_strategy_cls=task_distribution_strategy_cls,
             uses_adapters=uses_adapters,
         )
 
@@ -298,12 +296,13 @@ def global_to_local(self, node_rank, local_rank, opts):
         assert node_rank is not None
         assert local_rank is not None
         device_context = self.world_context.global_to_local(node_rank, local_rank)
+        task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
         return LocalTaskQueueManager(
             self.tasks,
             accum_count=self.accum_count,
             world_context=self.world_context,
             distributed_components=self.distributed_components,
-            task_distribution_strategy=self.task_distribution_strategy,
+            task_distribution_strategy=task_distribution_strategy,
             uses_adapters=self.uses_adapters,
             device_context=device_context,
         )
@@ -482,13 +481,14 @@ def __init__(
             tasks=tasks,
             accum_count=accum_count,
             world_context=world_context,
-            task_distribution_strategy=task_distribution_strategy,
+            task_distribution_strategy_cls=task_distribution_strategy.__class__,
             uses_adapters=uses_adapters,
             distributed_components=distributed_components,
         )
 
         assert device_context is not None
         self.device_context = device_context
+        self.task_distribution_strategy = task_distribution_strategy
 
         logger.info(f'in task_queue_manager: node_rank {self.node_rank} local_rank {self.local_rank}')
         self.device_context.validate(self.world_context)
diff --git a/mammoth/trainer.py b/mammoth/trainer.py
index 80390da7..cc0781e5 100644
--- a/mammoth/trainer.py
+++ b/mammoth/trainer.py
@@ -270,7 +270,6 @@ def train(
 
                 batches_with_meta = islice(train_iter, self.accum_count)
 
                 batch_task_sample = self.task_queue_manager.sample_corpus_ids()
-                logger.info(f'batch_task_sample has {batch_task_sample.training_step}')
                 my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank]
 
                 self._gradient_accumulation(