Skip to content

Commit

Permalink
Initialize TaskDistributionStrategy in global_to_local
Browse files Browse the repository at this point in the history
This guarantees that the producer and consumer don't share state
  • Loading branch information
Waino committed Feb 26, 2024
1 parent 98366a3 commit 80b508b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
16 changes: 8 additions & 8 deletions mammoth/distributed/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def __init__(
tasks: List[TaskSpecs],
accum_count: int,
world_context: WorldContext,
task_distribution_strategy_cls: type,
distributed_components=None,
task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
uses_adapters: bool = False,
):
"""
Expand All @@ -190,7 +190,7 @@ def __init__(
self.tasks = tasks
# TODO: no support for variable accumulation across training
self.accum_count = accum_count[0] if isinstance(accum_count, list) else accum_count
self.task_distribution_strategy = task_distribution_strategy
self.task_distribution_strategy_cls = task_distribution_strategy_cls
self.world_context = world_context
self.uses_adapters = uses_adapters

Expand Down Expand Up @@ -244,9 +244,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
if not len(opts.dec_layers) == 1:
raise Exception('With more than one decoder stack, you must explictly define dec_sharing_group')

task_distribution_strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy](
seed=opts.seed,
)
task_distribution_strategy_cls = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy]
tasks = []
uses_adapters = False
for (
Expand Down Expand Up @@ -290,20 +288,21 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
tasks,
world_context=world_context,
accum_count=opts.accum_count,
task_distribution_strategy=task_distribution_strategy,
task_distribution_strategy_cls=task_distribution_strategy_cls,
uses_adapters=uses_adapters,
)

def global_to_local(self, node_rank, local_rank, opts):
    """Create a LocalTaskQueueManager for one (node_rank, local_rank) device.

    A fresh TaskDistributionStrategy is instantiated here (rather than reusing
    a shared one) so that each local manager — e.g. a data-producer process and
    the consumer/trainer process — has its own independent sampling state.

    Args:
        node_rank: index of the node; must not be None.
        local_rank: index of the device within the node; must not be None.
        opts: parsed options namespace; only ``opts.seed`` is read here, to
            seed the per-process task distribution strategy.

    Returns:
        A LocalTaskQueueManager bound to the resolved device context.
    """
    assert node_rank is not None
    assert local_rank is not None
    device_context = self.world_context.global_to_local(node_rank, local_rank)
    # Instantiate per-process: guarantees producer and consumer don't share state.
    task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
    return LocalTaskQueueManager(
        self.tasks,
        accum_count=self.accum_count,
        world_context=self.world_context,
        distributed_components=self.distributed_components,
        # NOTE: the scraped diff showed both the old keyword
        # (task_distribution_strategy=self.task_distribution_strategy) and this
        # one; keeping both is a SyntaxError (repeated keyword argument) and the
        # old attribute no longer exists — only the freshly built instance is passed.
        task_distribution_strategy=task_distribution_strategy,
        uses_adapters=self.uses_adapters,
        device_context=device_context,
    )
Expand Down Expand Up @@ -482,13 +481,14 @@ def __init__(
tasks=tasks,
accum_count=accum_count,
world_context=world_context,
task_distribution_strategy=task_distribution_strategy,
task_distribution_strategy_cls=task_distribution_strategy.__class__,
uses_adapters=uses_adapters,
distributed_components=distributed_components,
)

assert device_context is not None
self.device_context = device_context
self.task_distribution_strategy = task_distribution_strategy

logger.info(f'in task_queue_manager: node_rank {self.node_rank} local_rank {self.local_rank}')
self.device_context.validate(self.world_context)
Expand Down
1 change: 0 additions & 1 deletion mammoth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def train(
batches_with_meta = islice(train_iter, self.accum_count)

batch_task_sample = self.task_queue_manager.sample_corpus_ids()
logger.info(f'batch_task_sample has {batch_task_sample.training_step}')
my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank]

self._gradient_accumulation(
Expand Down

0 comments on commit 80b508b

Please sign in to comment.