WIP: refactor distributed components
Tests are passing, but the refactor is still only partway done
Waino committed Feb 26, 2024
1 parent 5d6a606 commit 3ec9f22
Showing 11 changed files with 382 additions and 327 deletions.
5 changes: 3 additions & 2 deletions mammoth/distributed/components.py
@@ -19,7 +19,8 @@ def add(self, component):
else:
# already seen component must be merged
old_component = self.components[name]
assert type(old_component) == type(component)
assert type(old_component) == type(component), \
f'Unexpected type {name}: {old_component} != {component}'
assert old_component.group is None
assert component.group is None
old_component.global_ranks.update(component.global_ranks)
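
The merge-by-union behavior asserted here can be illustrated with a small self-contained sketch; the hypothetical ToyComponent below stands in for the real DistributedComponent, whose full definition is not part of this diff:

from dataclasses import dataclass, field
from typing import Optional, Set

@dataclass
class ToyComponent:
    # Hypothetical stand-in for DistributedComponent
    name: str
    group: Optional[object] = None
    global_ranks: Set[int] = field(default_factory=set)

components = {}

def add(component: ToyComponent):
    if component.name not in components:
        components[component.name] = component
    else:
        # already seen component must be merged
        old_component = components[component.name]
        assert type(old_component) == type(component), \
            f'Unexpected type {component.name}: {old_component} != {component}'
        # merging is only valid before communication groups are created
        assert old_component.group is None
        assert component.group is None
        old_component.global_ranks.update(component.global_ranks)

add(ToyComponent('encoder_0', global_ranks={0}))
add(ToyComponent('encoder_0', global_ranks={1}))
assert components['encoder_0'].global_ranks == {0, 1}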
@@ -103,7 +104,7 @@ def get_module(self, model: NMTModel) -> nn.Module:
class DistributedDecoder(DistributedXCoder):
@property
def side(self) -> Side:
return Side.encoder
return Side.decoder

@property
def decoder_id(self) -> str:
184 changes: 55 additions & 129 deletions mammoth/distributed/tasks.py
@@ -1,7 +1,7 @@
"""sub-module defining tasks, task specifications and task management objects."""
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import OrderedDict, namedtuple, defaultdict, Counter
from collections import namedtuple, defaultdict, Counter
from dataclasses import dataclass
from itertools import cycle, islice
from pprint import pformat
@@ -22,8 +22,8 @@
DistributedGenerator,
DistributedAdapter,
DistributedAttentionBridge,
DistributedComponentAction,
DistributedComponentActionWithGradient,
# DistributedComponentAction,
# DistributedComponentActionWithGradient,
)
from mammoth.utils.logging import logger

@@ -94,18 +94,6 @@ def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
pass


# # Sanity check of weights and curriculum
# assert len(self.my_corpus_ids) == len(self.my_weights)
# assert len(self.my_corpus_ids) == len(self.my_introduce_at_training_step)
# if len(self.my_corpus_ids) == 0:
# raise ValueError('No corpora on device')
# if sum(my_weights) <= 0:
# raise ValueError('Can not set "weight" of all corpora on a device to zero')
# if all(x > 0 for x in my_introduce_at_training_step):
# raise ValueError('Can not set "introduce_at_training_step" of all corpora on a device to nonzero')
# if all(weight == 0 or start > 0 for (weight, start) in zip(my_weights, my_introduce_at_training_step)):
# raise ValueError('Invalid curriculum: no corpus is ready to start in the first step')

class WeightedSamplingTaskDistributionStrategy(TaskDistributionStrategy):
"""
Schedules tasks by sampling with replacement from a categorical distribution.
@@ -124,7 +112,8 @@ def __init__(
def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
result: Dict[int, TaskSpecs] = dict()
for global_rank in sorted(active_tasks.keys()):
tasks, weights = zip(*active_tasks[global_rank])
tasks = active_tasks[global_rank]
weights = [task.weight for task in tasks]
sum_w = sum(weights)
assert sum_w > 0
p = [weight / sum_w for weight in weights]
@@ -144,13 +133,12 @@ class RoundRobinTaskDistributionStrategy(TaskDistributionStrategy):
"""

def __init__(self, seed: int):
super().__init__(seed)
super().__init__()

def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
self.training_step += 1
result: Dict[int, TaskSpecs] = dict()
for global_rank in sorted(active_tasks.keys()):
tasks, _ = zip(*active_tasks[global_rank])
tasks = active_tasks[global_rank]
sampled_task = tasks[self.training_step % len(tasks)]
result[global_rank] = sampled_task
bts = BatchTaskSample(tasks=result, training_step=self.training_step)
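
Both strategies reduce to a simple per-rank choice over that rank's active tasks. A minimal sketch of the two sampling rules, assuming only that TaskSpecs exposes .weight as used above:

import random

def weighted_sample(tasks, rng: random.Random):
    # WeightedSampling: draw with replacement from a categorical
    # distribution proportional to each task's weight.
    weights = [task.weight for task in tasks]
    assert sum(weights) > 0
    return rng.choices(tasks, weights=weights, k=1)[0]

def round_robin_sample(tasks, training_step: int):
    # RoundRobin: cycle deterministically through the tasks on this device.
    return tasks[training_step % len(tasks)]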
@@ -184,8 +172,7 @@ def __init__(
tasks: List[TaskSpecs],
accum_count: int,
world_context: WorldContext,
components_to_gpus=None,
components_to_groups=None,
distributed_components=None,
task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
uses_adapters: bool = False,
):
@@ -207,9 +194,7 @@ def __init__(
self.world_context = world_context
self.uses_adapters = uses_adapters

self.components_to_gpus = components_to_gpus
self.components_to_groups = components_to_groups
self.sampled_task_counts = Counter()
self.distributed_components = distributed_components

@property
def gpus_per_node(self):
@@ -259,7 +244,9 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
if len(opts.dec_layers) != 1:
raise Exception('With more than one decoder stack, you must explicitly define dec_sharing_group')

task_distribution_strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy]()
task_distribution_strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy](
seed=opts.seed,
)
tasks = []
uses_adapters = False
for (
@@ -315,8 +302,7 @@ def global_to_local(self, node_rank, local_rank, opts):
self.tasks,
accum_count=self.accum_count,
world_context=self.world_context,
components_to_gpus=self.components_to_gpus,
components_to_groups=self.components_to_groups,
distributed_components=self.distributed_components,
task_distribution_strategy=self.task_distribution_strategy,
uses_adapters=self.uses_adapters,
device_context=device_context,
@@ -366,7 +352,7 @@ def create_all_distributed_components(
self,
use_attention_bridge: bool,
new_group_func=torch.distributed.new_group,
):
) -> List[DistributedComponent]:
"""
Creates DistributedComponent objects.
For all components that are on more than one device, creates a communication group.
@@ -458,6 +444,7 @@ def create_all_distributed_components(
logger.info(f'{component.get_name()} is on a single device')

sorted_components = list(builder)
self.distributed_components = sorted_components
return sorted_components

def get_langs(self, side):
@@ -475,8 +462,7 @@ def __init__(
tasks: List[TaskSpecs],
accum_count: int,
world_context: WorldContext,
components_to_gpus=None,
components_to_groups=None,
distributed_components=None,
task_distribution_strategy: TaskDistributionStrategy = None,
uses_adapters: bool = False,
device_context: Optional[DeviceContext] = None,
@@ -498,17 +484,36 @@ def __init__(
world_context=world_context,
task_distribution_strategy=task_distribution_strategy,
uses_adapters=uses_adapters,
components_to_gpus=components_to_gpus,
components_to_groups=components_to_groups,
distributed_components=distributed_components,
)

assert device_context is not None
self.device_context = device_context

logger.info(f'in task_queue_manager: node_rank {self.node_rank} local_rank {self.local_rank}')
self.device_context.validate(self.world_context)
self._sanity_check_tasks()

self.sampled_task_counts = Counter()
self.my_distributed_components = None

def _sanity_check_tasks(self):
my_corpus_ids = [task.corpus_id for task in self.get_my_tasks()]
my_weights = [task.weight for task in self.get_my_tasks()]
my_introduce_at_training_step = [
task.introduce_at_training_step for task in self.get_my_tasks()
]
# Sanity check of weights and curriculum
assert len(my_corpus_ids) == len(my_weights)
assert len(my_corpus_ids) == len(my_introduce_at_training_step)
if len(my_corpus_ids) == 0:
raise ValueError('No corpora on device')
if sum(my_weights) <= 0:
raise ValueError('Can not set "weight" of all corpora on a device to zero')
if all(x > 0 for x in my_introduce_at_training_step):
raise ValueError('Can not set "introduce_at_training_step" of all corpora on a device to nonzero')
if all(weight == 0 or start > 0 for (weight, start) in zip(my_weights, my_introduce_at_training_step)):
raise ValueError('Invalid curriculum: no corpus is ready to start in the first step')

@property
def node_rank(self):
@@ -522,106 +527,27 @@ def local_rank(self):
def global_rank(self):
return self.node_rank * self.gpus_per_node + self.local_rank

def get_my_distributed_groups(
self,
new_group_func=torch.distributed.new_group,
):
"""
Returns pairs of (component_id, process_group).
Only components present on this GPU are returned.
The pairs are returned in a consistent order across GPUs.
"""
if self.components_to_groups is None:
self.create_all_distributed_groups(new_group_func)
logger.info(f'components_to_groups: {self.components_to_groups}')

my_distributed_groups = {
'encoder': OrderedDict(),
'decoder': OrderedDict(),
'src_emb': OrderedDict(),
'tgt_emb': OrderedDict(),
'encoder_adapters': OrderedDict(),
'decoder_adapters': OrderedDict(),
}

if self.global_rank is None:
# Training on CPU, or called on global TaskQueueManager
for component_type, components in self.components_to_groups.items():
my_distributed_groups[component_type] = components

global_rank = self.global_rank

for key, global_ranks in self.components_to_gpus.items():
if global_rank not in global_ranks:
# omit groups that are not on this device
continue
component_type = key[0]
component_id = key[1:]
if component_id not in self.components_to_groups[component_type]:
# omit components on a single device
logger.info(f'{component_type} {component_id} is on a single device')
continue
my_distributed_groups[component_type][component_id] = \
self.components_to_groups[component_type][component_id]

return my_distributed_groups

def get_my_grouped_components(self, model):
"""
Returns nested dict of component_type -> component_id -> nn.Module.
Only components present on this GPU are returned.
Unlike get_my_distributed_groups, this method also returns components on a single device,
and it does not retrieve communication groups.
"""
if self.components_to_groups is None:
raise Exception('Must call get_my_distributed_groups first')

my_grouped_components = {
'encoder': OrderedDict(),
'decoder': OrderedDict(),
'src_emb': OrderedDict(),
'tgt_emb': OrderedDict(),
'encoder_adapters': OrderedDict(),
'decoder_adapters': OrderedDict(),
}

if not self.world_context.is_distributed():
tasks = self.tasks
else:
tasks = self.get_my_tasks()

for task in tasks:
# loop over my tasks, getting all the relevant module ids and modules
my_grouped_components['src_emb'][task.src_lang] = model.encoder.embeddings[f'embeddings_{task.src_lang}']
my_grouped_components['tgt_emb'][task.tgt_lang] = model.decoder.embeddings[f'embeddings_{task.tgt_lang}']
for layer_stack_index, encoder_id in enumerate(task.encoder_id):
component = model.encoder.get_submodule(layer_stack_index, encoder_id)
my_grouped_components['encoder'][(layer_stack_index, encoder_id)] = component
for layer_stack_index, decoder_id in enumerate(task.decoder_id):
component = model.decoder.get_submodule(layer_stack_index, decoder_id)
my_grouped_components['decoder'][(layer_stack_index, decoder_id)] = component
if task.encoder_adapter_ids:
for layer_stack_index, adapter_group, sub_id in task.encoder_adapter_ids:
encoder_id = task.encoder_id[layer_stack_index]
key = (layer_stack_index, encoder_id, adapter_group, sub_id)
component = model.encoder.get_submodule(
layer_stack_index, encoder_id
).get_adapter(adapter_group, sub_id)
my_grouped_components['encoder_adapters'][key] = component
if task.decoder_adapter_ids:
for layer_stack_index, adapter_group, sub_id in task.decoder_adapter_ids:
decoder_id = task.decoder_id[layer_stack_index]
key = (layer_stack_index, decoder_id, adapter_group, sub_id)
component = model.decoder.get_submodule(
layer_stack_index, decoder_id
).get_adapter(adapter_group, sub_id)
my_grouped_components['decoder_adapters'][key] = component

return my_grouped_components
def get_my_distributed_components(self) -> List[DistributedComponent]:
if self.distributed_components is None:
raise Exception('Call create_all_distributed_components first')
if not self.my_distributed_components:
my_global_rank = self.global_rank
self.my_distributed_components = [
component
for component in self.distributed_components
if my_global_rank in component.global_ranks
]
return self.my_distributed_components

def sample_corpus_ids(self):
def sample_corpus_ids(self) -> BatchTaskSample:
active_tasks: Dict[int, List[TaskSpecs]] = self.get_active_tasks()
batch_task_sample = self.task_distribution_strategy.sample_corpus_ids(active_tasks)
if self.global_rank is None or self.global_rank == 0:
# Only track sampled_task_counts on the master device.
# Every TQM (both data loader and trainer for every device) has access to the global info
self.sampled_task_counts.update(
[task.corpus_id for task in batch_task_sample.tasks.values()]
)
return batch_task_sample

def get_my_encoders(self, layer_stack_index: int):
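
Taken together, the tasks.py changes imply a call order along these lines. This is a usage sketch under stated assumptions: that the manager class is named TaskQueueManager, that it is importable from mammoth.distributed.tasks (the module shown above), and that opts and world_context are constructed elsewhere as usual:

from mammoth.distributed.tasks import TaskQueueManager

def build_local_components(opts, world_context, node_rank, local_rank):
    global_tqm = TaskQueueManager.from_opts(opts, world_context)
    # Populates and caches self.distributed_components as a side effect.
    global_tqm.create_all_distributed_components(use_attention_bridge=False)
    # The cached list is passed through to the per-device manager ...
    local_tqm = global_tqm.global_to_local(node_rank, local_rank, opts)
    # ... which filters it down to the components whose global_ranks
    # include this device's global rank.
    return local_tqm.get_my_distributed_components()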
20 changes: 6 additions & 14 deletions mammoth/inputters/dataloader.py
@@ -337,14 +337,13 @@ def __iter__(self):
yield from itertools.chain.from_iterable(all_val_data)

else:
# All minibatches with the same communication_batch_id should be trained on
# before synching gradients between devices
communication_batch_id = 0
while True:
for corpus_id in self.task_queue_manager.sample_corpus_ids(communication_batch_id):
ordered_iter, metadata = self.dataset_iterators[corpus_id]
batch_task_sample = self.task_queue_manager.sample_corpus_ids()
my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank]
ordered_iter, metadata = self.dataset_iterators[my_task.corpus_id]
for _ in range(self.task_queue_manager.accum_count):
batch = next(ordered_iter)
if communication_batch_id == 0:
if batch_task_sample.training_step == 0:
# De-numericalize a few sentences for debugging
logger.warning(
f'src shape: {batch.src[0].shape} tgt shape: {batch.tgt.shape} '
@@ -357,11 +356,4 @@ def __iter__(self):
logger.warning(f'{sent_idx} {metadata.src_lang} src: {" ".join(toks)}')
toks = [tgt_vocab.itos[tok_id.item()] for tok_id in batch.tgt[:, sent_idx, 0]]
logger.warning(f'{sent_idx} {metadata.tgt_lang} tgt: {" ".join(toks)}')
yield batch, metadata, communication_batch_id

communication_batch_id += 1
if communication_batch_id % 1000 == 0:
total = sum(self.task_queue_manager.sampled_task_counts.values())
logger.info(f'Task sampling distribution: (total {total})')
for task, count in self.task_queue_manager.sampled_task_counts.most_common():
logger.info(f'Task: {task}\tcount: {count}\t{100 * count / total} %')
yield batch, metadata, batch_task_sample.training_step
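
The loop above changes the iteration protocol: instead of a locally maintained communication_batch_id, each training step starts from one BatchTaskSample, and the device consumes accum_count minibatches from its own sampled task. A minimal consumer sketch, assuming only that the iterator yields (batch, metadata, training_step) triples as above:

from itertools import groupby

def batches_per_step(dataloader):
    # Group the yielded triples by training step: every group holds the
    # accum_count minibatches of one communication batch, all drawn from
    # the single task sampled for this device, so gradients can be
    # synced once per step.
    for step, group in groupby(dataloader, key=lambda item: item[2]):
        yield step, [(batch, metadata) for batch, metadata, _ in group]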
2 changes: 1 addition & 1 deletion mammoth/modules/layer_stack_encoder.py
@@ -1,5 +1,5 @@
from torch import nn
from typing import Dict, List, Optional
from typing import Dict, List

from mammoth.modules.encoder import EncoderBase
from mammoth.models.adapters import Adapter, AdaptedTransformerEncoder
5 changes: 3 additions & 2 deletions mammoth/opts.py
@@ -93,8 +93,9 @@ def _add_reproducibility_opts(parser):
'--seed',
'-seed',
type=int,
default=-1,
help="Set random seed used for better reproducibility between experiments.",
required=True,
help="Set random seed used for better reproducibility between experiments. "
"Mandatory for multi-gpu training, and for convenience required for all.",
)


5 changes: 4 additions & 1 deletion mammoth/tests/test_models.py
@@ -14,7 +14,10 @@
mammoth.opts._add_train_general_opts(parser)

# -data option is required, but not used in this test, so dummy.
opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'], strict=False)[0]
opts = parser.parse_known_args(
['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500', '-seed', '1'],
strict=False
)[0]


class TestModel(unittest.TestCase):