diff --git a/mammoth/bin/translate.py b/mammoth/bin/translate.py
index 63a5ed1c..4deb9e2e 100644
--- a/mammoth/bin/translate.py
+++ b/mammoth/bin/translate.py
@@ -38,6 +38,7 @@ def translate(opts):
         decoder_id=decoder_id,
         corpus_id=corpus_id,
         weight=1.0,
+        introduce_at_training_step=0,
         corpus_opts=corpus_opts,
         src_vocab=None,
         tgt_vocab=None,
diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py
index 30237c03..c38619ae 100644
--- a/mammoth/distributed/components.py
+++ b/mammoth/distributed/components.py
@@ -2,7 +2,33 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import List, Any, Optional
+from typing import Set, Any, Optional, Dict
+
+from mammoth.models import NMTModel
+
+
+class DistributedComponentBuilder:
+    def __init__(self):
+        self.components: Dict[str, DistributedComponent] = dict()
+
+    def add(self, component):
+        name = component.get_name()
+        if name not in self.components:
+            # new component
+            self.components[name] = component
+        else:
+            # already seen component must be merged
+            old_component = self.components[name]
+            assert type(old_component) is type(component)
+            assert old_component.group is None
+            assert component.group is None
+            old_component.global_ranks.update(component.global_ranks)
+
+    def __iter__(self):
+        result = []
+        for key in sorted(self.components.keys()):
+            result.append(self.components[key])
+        return iter(result)
 
 
 class Side(Enum):
@@ -16,8 +42,13 @@ class DistributedComponent(ABC):
     Represents a model component that may be distributed across several
     devices according to some parameter sharing pattern.
     """
-    module: nn.Module
-    ranks: List[int]
+    # This was implemented as a separate dataclass instead of making it a mixin
+    # of the nn.Module. The main reason is the need to create and use the
+    # DistributedComponents also in contexts where an initialized model is not
+    # (yet) available: 1) in the dataloader, 2) (after future refactoring) when
+    # creating the Modules that the model consists of.
+
+    global_ranks: Set[int]
     # distributed communication group object, or None if on a single device
     group: Optional[Any]
 
@@ -25,30 +56,63 @@ class DistributedComponent(ABC):
     def get_name(self) -> str:
         pass
 
-    def named_parameters(self):
-        yield from self.module.named_parameters()
+    @abstractmethod
+    def get_module(self, model: NMTModel) -> nn.Module:
+        pass
+
+    def named_parameters(self, model: NMTModel):
+        module = self.get_module(model)
+        yield from module.named_parameters()
 
     def min_rank(self) -> int:
-        return min(self.ranks)
+        return min(self.global_ranks)
 
 
 @dataclass
-class DistributedXCoder(DistributedComponent):
-    side: Side
+class DistributedXCoder(DistributedComponent, ABC):
     layer_stack_index: int
     xcoder_id: str
 
     def get_name(self) -> str:
         return f'{self.side.name}_{self.layer_stack_index}_{self.xcoder_id}'
 
-    def named_parameters(self):
-        for name, p in self.module.named_parameters():
+    def named_parameters(self, model: NMTModel):
+        module = self.get_module(model)
+        for name, p in module.named_parameters():
             # encoders and decoders contain embeddings and adapters as submodules
             # however, we want to treat these as distinct DistributedComponents
             if 'embeddings' not in name and 'adapter' not in name:
                 yield name, p
 
 
+@dataclass
+class DistributedEncoder(DistributedXCoder):
+    @property
+    def side(self) -> Side:
+        return Side.encoder
+
+    @property
+    def encoder_id(self) -> str:
+        return self.xcoder_id
+
+    def get_module(self, model: NMTModel) -> nn.Module:
+        return model.encoder.get_submodule(self.layer_stack_index, self.xcoder_id)
+
+
+@dataclass
+class DistributedDecoder(DistributedXCoder):
+    @property
+    def side(self) -> Side:
+        return Side.decoder
+
+    @property
+    def decoder_id(self) -> str:
+        return self.xcoder_id
+
+    def get_module(self, model: NMTModel) -> nn.Module:
+        return model.decoder.get_submodule(self.layer_stack_index, self.xcoder_id)
+
+
 @dataclass
 class DistributedEmbedding(DistributedComponent):
     side: Side
@@ -58,6 +122,12 @@ def get_name(self) -> str:
         side_str = 'src' if self.side == Side.encoder else 'tgt'
         return f'{side_str}_embeddings_{self.lang}'
 
+    def get_module(self, model: NMTModel) -> nn.Module:
+        if self.side == Side.encoder:
+            return model.encoder.embeddings[f'embeddings_{self.lang}']
+        else:
+            return model.decoder.embeddings[f'embeddings_{self.lang}']
+
 
 @dataclass
 class DistributedGenerator(DistributedComponent):
@@ -66,9 +136,14 @@ class DistributedGenerator(DistributedComponent):
     def get_name(self) -> str:
         return f'generator_{self.lang}'
 
+    def get_module(self, model: NMTModel) -> nn.Module:
+        return model.generator[f'generator_{self.lang}']
+
 
 @dataclass
 class DistributedAdapter(DistributedComponent):
+    # Can't use parent object of type DistributedXCoder: that refers to a
+    # specific module, while the adapter is for the entire layerstack slot
     side: Side
     layer_stack_index: int
     adapter_group: str
@@ -77,12 +152,21 @@ class DistributedAdapter(DistributedComponent):
     def get_name(self) -> str:
         return f'{self.side.name}_adapter_{self.layer_stack_index}_{self.adapter_group}_{self.sub_id}'
 
+    def get_module(self, model: NMTModel) -> nn.Module:
+        if self.side == Side.encoder:
+            return model.encoder.get_adapter(self.adapter_group, self.sub_id)
+        else:
+            return model.decoder.get_adapter(self.adapter_group, self.sub_id)
+
 
 @dataclass
 class DistributedAttentionBridge(DistributedComponent):
     def get_name(self) -> str:
         return 'attention_bridge'
 
+    def get_module(self, model: NMTModel) -> Optional[nn.Module]:
+        return model.attention_bridge
+
 
 @dataclass
 class DistributedComponentAction:
@@ -94,7 +178,7 @@ class DistributedComponentAction:
 
 
 @dataclass
-class DistributedComponentActionGradient(DistributedComponentAction):
+class DistributedComponentActionWithGradient(DistributedComponentAction):
     # True: has a real gradient that needs to be communicated
     # False: send a zero dummy gradient, receive gradient from others
     has_local_gradient: bool
diff --git a/mammoth/distributed/tasks.py b/mammoth/distributed/tasks.py
index c5502625..6e084540 100644
--- a/mammoth/distributed/tasks.py
+++ b/mammoth/distributed/tasks.py
@@ -1,124 +1,33 @@
 """sub-module defining tasks, task specifications and task management objects."""
 from abc import ABC, abstractmethod
 from argparse import Namespace
-from collections import OrderedDict, namedtuple, Counter
+from collections import OrderedDict, namedtuple, defaultdict, Counter
 from dataclasses import dataclass
 from itertools import cycle, islice
 from pprint import pformat
-from typing import Any, Optional, List, Tuple
+from typing import Any, Optional, List, Tuple, Dict
 
 import numpy as np
 import torch
 import torch.distributed
 
 from mammoth.distributed.contexts import DeviceContext, WorldContext
+from mammoth.distributed.components import (
+    Side,
+    DistributedComponentBuilder,
+    DistributedComponent,
+    DistributedEncoder,
+    DistributedDecoder,
+    DistributedEmbedding,
+    DistributedGenerator,
+    DistributedAdapter,
+    DistributedAttentionBridge,
+    DistributedComponentAction,
+    DistributedComponentActionWithGradient,
+)
 from mammoth.utils.logging import logger
 
 
-class TaskDistributionStrategy(ABC):
-    """
-    An abstract task distribution strategy, controls which task will be scheduled next.
-    """
-    @abstractmethod
-    def __init__(self, my_corpus_ids: List[str], **kwargs):
-        pass
-
-    @classmethod
-    @abstractmethod
-    def from_opts(cls, my_corpus_ids: List[str], opts: dict):
-        """Alternative constructor."""
-        pass
-
-    @abstractmethod
-    def sample_corpus_ids(self, n_samples: int, communication_batch_id: int) -> List[str]:
-        """Select corpora to sample from."""
-        pass
-
-
-class WeightedSamplingTaskDistributionStrategy(TaskDistributionStrategy):
-    """
-    Schedules tasks by sampling with replacement from a categorical distribution.
-    The probabilities are found by normalizing the weights of all valid tasks (corpora).
-    Valid tasks are those that are present on this device, and have already reached
-    their curriculum starting point "introduce_at_training_step".
-    """
-
-    def __init__(
-        self,
-        my_corpus_ids: List[str],
-        my_weights: List[float],
-        my_introduce_at_training_step: List[int]
-    ):
-        self.my_corpus_ids = my_corpus_ids
-        self.my_weights = my_weights
-        self.my_introduce_at_training_step = my_introduce_at_training_step
-
-        # Sanity check of weights and curriculum
-        assert len(self.my_corpus_ids) == len(self.my_weights)
-        assert len(self.my_corpus_ids) == len(self.my_introduce_at_training_step)
-        if len(self.my_corpus_ids) == 0:
-            raise ValueError('No corpora on device')
-        if sum(my_weights) <= 0:
-            raise ValueError('Can not set "weight" of all corpora on a device to zero')
-        if all(x > 0 for x in my_introduce_at_training_step):
-            raise ValueError('Can not set "introduce_at_training_step" of all corpora on a device to nonzero')
-        if all(weight == 0 or start > 0 for (weight, start) in zip(my_weights, my_introduce_at_training_step)):
-            raise ValueError('Invalid curriculum: no corpus is ready to start in the first step')
-
-    @classmethod
-    def from_opts(cls, my_corpus_ids: List[str], opts: dict):
-        my_weights = [opts.tasks[corpus_id]['weight'] for corpus_id in my_corpus_ids]
-        my_introduce_at_training_step = [
-            opts.tasks[corpus_id]['introduce_at_training_step'] for corpus_id in my_corpus_ids
-        ]
-        return cls(my_corpus_ids, my_weights, my_introduce_at_training_step)
-
-    def sample_corpus_ids(
-        self,
-        n_samples: int,
-        communication_batch_id: int,
-    ):
-        weights = [
-            weight if introduce_at_training_step <= communication_batch_id else 0
-            for (corpus_id, weight, introduce_at_training_step) in zip(
-                self.my_corpus_ids, self.my_weights, self.my_introduce_at_training_step
-            )
-        ]
-        sum_w = sum(weights)
-        assert sum_w > 0
-        p = [weight / sum_w for weight in weights]
-        # sampling with replacement from weighted corpora (language pairs)
-        sampled_corpus_ids = np.random.choice(self.my_corpus_ids, size=n_samples, p=p)
-        return sampled_corpus_ids
-
-
-class RoundRobinTaskDistributionStrategy(TaskDistributionStrategy):
-    """
-    Schedules tasks (corpora) in a round-robin fashion.
-    Yields a communication batch of n_samples at a time.
-    When reaching the end of the list of tasks, starts over from the beginning.
-    """
-
-    def __init__(self, my_corpus_ids: List[str]):
-        self.infinite_corpus_ids = cycle(my_corpus_ids)
-
-    @classmethod
-    def from_opts(cls, my_corpus_ids: List[str], opts: dict):
-        return cls(my_corpus_ids)
-
-    def sample_corpus_ids(
-        self,
-        n_samples: int,
-        communication_batch_id: int,
-    ):
-        return list(islice(self.infinite_corpus_ids, n_samples))
-
-
-TASK_DISTRIBUTION_STRATEGIES = {
-    'weighted_sampling': WeightedSamplingTaskDistributionStrategy,
-    'roundrobin': RoundRobinTaskDistributionStrategy,
-}
-
 DatasetMetadata = namedtuple(
     'DatasetMetadata',
     'src_lang tgt_lang encoder_id decoder_id corpus_id encoder_adapter_ids decoder_adapter_ids'
@@ -135,6 +44,7 @@ class TaskSpecs():
     decoder_id: List[str]
     corpus_id: str
     weight: int
+    introduce_at_training_step: int
     corpus_opts: dict
     src_vocab: Any  # FIXME: type
     tgt_vocab: Any
@@ -158,6 +68,102 @@ def get_serializable_metadata(self):
         )
 
 
+@dataclass
+class BatchTaskSample:
+    """
+    A deterministically random sample of one task per device, to be trained in a single batch.
+    """
+    # maps from global rank to Task
+    tasks: Dict[int, TaskSpecs]
+    training_step: int
+
+
+class TaskDistributionStrategy(ABC):
+    """
+    An abstract task distribution strategy, controls which tasks will be scheduled next.
+    """
+    def __init__(self):
+        self.training_step = 0
+
+    @abstractmethod
+    def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
+        """
+        Select one task per device, to train on.
+        active_tasks[global_rank] -> list of TaskSpecs currently active on that device
+        """
+        pass
+
+
+# # Sanity check of weights and curriculum
+# assert len(self.my_corpus_ids) == len(self.my_weights)
+# assert len(self.my_corpus_ids) == len(self.my_introduce_at_training_step)
+# if len(self.my_corpus_ids) == 0:
+#     raise ValueError('No corpora on device')
+# if sum(my_weights) <= 0:
+#     raise ValueError('Can not set "weight" of all corpora on a device to zero')
+# if all(x > 0 for x in my_introduce_at_training_step):
+#     raise ValueError('Can not set "introduce_at_training_step" of all corpora on a device to nonzero')
+# if all(weight == 0 or start > 0 for (weight, start) in zip(my_weights, my_introduce_at_training_step)):
+#     raise ValueError('Invalid curriculum: no corpus is ready to start in the first step')
+
+class WeightedSamplingTaskDistributionStrategy(TaskDistributionStrategy):
+    """
+    Schedules tasks by sampling with replacement from a categorical distribution.
+    The probabilities are found by normalizing the weights of all valid tasks (corpora).
+    Valid tasks are those that are present on this device, and have already reached
+    their curriculum starting point "introduce_at_training_step".
+    """
+
+    def __init__(
+        self,
+        seed: int,
+    ):
+        super().__init__()
+        self.rng = np.random.default_rng(seed=seed)
+
+    def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
+        result: Dict[int, TaskSpecs] = dict()
+        for global_rank in sorted(active_tasks.keys()):
+            tasks = active_tasks[global_rank]
+            sum_w = sum(task.weight for task in tasks)
+            assert sum_w > 0
+            p = [task.weight / sum_w for task in tasks]
+            # sampling with replacement from weighted corpora (language pairs)
+            sampled_task = tasks[self.rng.choice(len(tasks), p=p)]
+            result[global_rank] = sampled_task
+        bts = BatchTaskSample(tasks=result, training_step=self.training_step)
+        self.training_step += 1
+        return bts
+
+
+class RoundRobinTaskDistributionStrategy(TaskDistributionStrategy):
+    """
+    Schedules tasks (corpora) in a round-robin fashion.
+    Selects the next task in turn for each device, once per training step.
+    When reaching the end of the list of tasks, starts over from the beginning.
+    """
+
+    def __init__(self, seed: int):
+        super().__init__()
+
+    def sample_corpus_ids(self, active_tasks: Dict[int, List[TaskSpecs]]) -> BatchTaskSample:
+        # training_step doubles as the round-robin position; advance it once, at the end
+        result: Dict[int, TaskSpecs] = dict()
+        for global_rank in sorted(active_tasks.keys()):
+            tasks = active_tasks[global_rank]
+            sampled_task = tasks[self.training_step % len(tasks)]
+            result[global_rank] = sampled_task
+        bts = BatchTaskSample(tasks=result, training_step=self.training_step)
+        self.training_step += 1
+        return bts
+
+
+TASK_DISTRIBUTION_STRATEGIES = {
+    'weighted_sampling': WeightedSamplingTaskDistributionStrategy,
+    'roundrobin': RoundRobinTaskDistributionStrategy,
+}
+
+
 def get_adapter_ids(opts, corpus_opts, side):
     if 'adapters' not in opts or 'adapters' not in corpus_opts:
         return []
@@ -253,6 +259,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
             if not len(opts.dec_layers) == 1:
                 raise Exception('With more than one decoder stack, you must explictly define dec_sharing_group')
 
+        task_distribution_strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy](seed=opts.seed)
         tasks = []
         uses_adapters = False
         for (
@@ -267,6 +274,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
             encoder_id = corpus_opts.get('enc_sharing_group', [src_lang])
             decoder_id = corpus_opts.get('dec_sharing_group', [tgt_lang])
             weight = corpus_opts.get('weight', 1.0)
+            introduce_at_training_step = corpus_opts.get('introduce_at_training_step', 0)
             if 'adapters' in corpus_opts:
                 encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder')
                 decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder')
@@ -283,6 +291,7 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
                 decoder_id=decoder_id,
                 corpus_id=corpus_id,
                 weight=weight,
+                introduce_at_training_step=introduce_at_training_step,
                 corpus_opts=corpus_opts,
                 src_vocab=None,
                 tgt_vocab=None,
@@ -294,13 +303,13 @@ def from_opts(cls, opts: Namespace, world_context: WorldContext):
             tasks,
             world_context=world_context,
             accum_count=opts.accum_count,
+            task_distribution_strategy=task_distribution_strategy,
             uses_adapters=uses_adapters,
         )
 
     def global_to_local(self, node_rank, local_rank, opts):
         assert node_rank is not None
         assert local_rank is not None
-        task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opts=opts)
         device_context = self.world_context.global_to_local(node_rank, local_rank)
         return LocalTaskQueueManager(
             self.tasks,
@@ -308,27 +317,11 @@ def global_to_local(self, node_rank, local_rank, opts):
             world_context=self.world_context,
             components_to_gpus=self.components_to_gpus,
             components_to_groups=self.components_to_groups,
-            task_distribution_strategy=task_distribution_strategy,
+            task_distribution_strategy=self.task_distribution_strategy,
             uses_adapters=self.uses_adapters,
             device_context=device_context,
         )
 
-    def _get_strategy(self, node_rank, local_rank, opts):
-        assert node_rank is not None
-        assert local_rank is not None
-        # Global TQM does not have a task distribution strategy, but the local ones do
-        my_corpus_ids = [task.corpus_id for task in self._tasks_on_device(node_rank, local_rank)]
-        try:
-            strategy = TASK_DISTRIBUTION_STRATEGIES[opts.task_distribution_strategy].from_opts(
-                my_corpus_ids=my_corpus_ids,
-                opts=opts,
-            )
-            return strategy
-        except Exception as e:
-            raise Exception(
-                f'Exception when creating task distribution strategy on {node_rank}:{local_rank} {e}'
-            )
-
     def __repr__(self):
         kwargs = ',\n '.join(
             f'{key}={pformat(self.__getattribute__(key))}'
@@ -350,6 +343,15 @@ def _tasks_on_device(self, node_rank, local_rank):
     def get_all_tasks(self):
         return self.tasks
 
+    def get_active_tasks(self) -> Dict[int, List[TaskSpecs]]:
+        result = defaultdict(list)
+        for task in self.tasks:
+            # TODO: DRY violation, this computation is implemented in many places
+            global_rank = task.node_rank * self.gpus_per_node + task.local_rank
+            if task.introduce_at_training_step <= self.task_distribution_strategy.training_step:
+                result[global_rank].append(task)
+        return result
+
     @staticmethod
     def _default_node_gpu(n_tasks, n_nodes, gpus_per_node):
         def yield_each_gpu():
@@ -360,80 +362,103 @@ def yield_each_gpu():
         # yield GPUs in rank order, repeat as necessary
         return list(islice(cycle(yield_each_gpu()), n_tasks))
 
-    def create_all_distributed_groups(
+    def create_all_distributed_components(
         self,
+        use_attention_bridge: bool,
         new_group_func=torch.distributed.new_group,
     ):
-        if not self.world_context.is_distributed():
-            self.components_to_gpus = dict()
-            self.components_to_groups = dict()
-            return self.components_to_groups
-
-        # Single OrderedDict contains all components.
-        # Keys are tuples of strings.
-        # The length of the key varies depending on the component:
-        # ('encoder', layer_stack_index, encoder_id)
-        # ('decoder', layer_stack_index, decoder_id)
-        # ('src_emb', lang)
-        # ('tgt_emb', lang)
-        # ('encoder_adapters', layer_stack_index, encoder_id, adapter_group, sub_id)
-        # ('decoder_adapters', layer_stack_index, decoder_id, adapter_group, sub_id)
-        self.components_to_gpus = OrderedDict()
-
-        for node_rank in range(self.n_nodes):
-            for local_rank in range(self.gpus_per_node):
-                global_rank = node_rank * self.gpus_per_node + local_rank
-                tasks = self._tasks_on_device(node_rank, local_rank)
-
-                for task in tasks:
-                    keys = [
-                        ('src_emb', task.src_lang),
-                        ('tgt_emb', task.tgt_lang),
-                    ]
-                    for layer_stack_index, encoder_id in enumerate(task.encoder_id):
-                        keys.append(('encoder', layer_stack_index, encoder_id))
-                    for layer_stack_index, decoder_id in enumerate(task.decoder_id):
-                        keys.append(('decoder', layer_stack_index, decoder_id))
-                    for key in keys:
-                        # Using setdefault to treat OrderedDict as defaultdict
-                        self.components_to_gpus.setdefault(key, set()).add(global_rank)
-
-                    if task.encoder_adapter_ids:
-                        for layer_stack_index, adapter_group, sub_id in task.encoder_adapter_ids:
-                            encoder_id = task.encoder_id[layer_stack_index]
-                            key = ('encoder_adapters', layer_stack_index, encoder_id, adapter_group, sub_id)
-                            self.components_to_gpus.setdefault(key, set()).add(global_rank)
-                    if task.decoder_adapter_ids:
-                        for layer_stack_index, adapter_group, sub_id in task.decoder_adapter_ids:
-                            decoder_id = task.decoder_id[layer_stack_index]
-                            key = ('decoder_adapters', layer_stack_index, decoder_id, adapter_group, sub_id)
-                            self.components_to_gpus.setdefault(key, set()).add(global_rank)
-
-        # Structured, each component in a separate OrderedDict
-        self.components_to_groups = {
-            component_type: OrderedDict() for component_type
-            in ('encoder', 'decoder', 'src_emb', 'tgt_emb')
-        }
-        if self.uses_adapters:
-            self.components_to_groups['encoder_adapters'] = OrderedDict()
-            self.components_to_groups['decoder_adapters'] = OrderedDict()
-        for key, global_ranks in self.components_to_gpus.items():
-            if len(global_ranks) < 2:
-                # only create a process group if the component is on 2 or more gpus
-                continue
-            sorted_global_ranks = list(sorted(global_ranks))
-            min_rank = sorted_global_ranks[0]
-            # The torch.distributed.new_group function requires that all
-            # processes in the main group (i.e. all processes that are part of
-            # the distributed job) enter the function, even if they are not
-            # going to be members of the group. Additionally, groups should be
-            # created in the same order in all processes.
-            group_tpl = (min_rank, new_group_func(sorted_global_ranks))
-            component_type = key[0]
-            component_id = key[1:]
-            self.components_to_groups.setdefault(component_type, OrderedDict())[component_id] = group_tpl
-
-        return self.components_to_groups
+        """
+        Creates DistributedComponent objects.
+        For all components that are on more than one device, creates a communication group.
+        """
+        builder = DistributedComponentBuilder()
+        for task in self.tasks:
+            # TODO: DRY violation, this computation is implemented in many places
+            global_rank = task.node_rank * self.gpus_per_node + task.local_rank
+            builder.add(
+                DistributedEmbedding(
+                    global_ranks={global_rank},
+                    group=None,
+                    side=Side.encoder,
+                    lang=task.src_lang,
+                )
+            )
+            builder.add(
+                DistributedEmbedding(
+                    global_ranks={global_rank},
+                    group=None,
+                    side=Side.decoder,
+                    lang=task.tgt_lang,
+                )
+            )
+            builder.add(
+                DistributedGenerator(
+                    global_ranks={global_rank},
+                    group=None,
+                    lang=task.tgt_lang,
+                )
+            )
+            for layer_stack_index, encoder_id in enumerate(task.encoder_id):
+                builder.add(
+                    DistributedEncoder(
+                        global_ranks={global_rank},
+                        group=None,
+                        layer_stack_index=layer_stack_index,
+                        xcoder_id=encoder_id,
+                    )
+                )
+            for layer_stack_index, decoder_id in enumerate(task.decoder_id):
+                builder.add(
+                    DistributedDecoder(
+                        global_ranks={global_rank},
+                        group=None,
+                        layer_stack_index=layer_stack_index,
+                        xcoder_id=decoder_id,
+                    )
+                )
+            if task.encoder_adapter_ids:
+                for layer_stack_index, adapter_group, sub_id in task.encoder_adapter_ids:
+                    builder.add(
+                        DistributedAdapter(
+                            global_ranks={global_rank},
+                            group=None,
+                            side=Side.encoder,
+                            layer_stack_index=layer_stack_index,
+                            adapter_group=adapter_group,
+                            sub_id=sub_id,
+                        )
+                    )
+            if task.decoder_adapter_ids:
+                for layer_stack_index, adapter_group, sub_id in task.decoder_adapter_ids:
+                    builder.add(
+                        DistributedAdapter(
+                            global_ranks={global_rank},
+                            group=None,
+                            side=Side.decoder,
+                            layer_stack_index=layer_stack_index,
+                            adapter_group=adapter_group,
+                            sub_id=sub_id,
+                        )
+                    )
+            if use_attention_bridge:
+                builder.add(DistributedAttentionBridge(global_ranks={global_rank}, group=None))
+
+        # once all DistributedComponents are created, we can initialize communication groups
+        if self.world_context.is_distributed():
+            for component in builder:
+                # do not create communication groups for components on a single device
+                if len(component.global_ranks) > 1:
+                    # The torch.distributed.new_group function requires that all
+                    # processes in the main group (i.e. all processes that are part of
+                    # the distributed job) enter the function, even if they are not
+                    # going to be members of the group. Additionally, groups should be
+                    # created in the same order in all processes.
+                    component.group = new_group_func(sorted(component.global_ranks))
+                else:
+                    logger.info(f'{component.get_name()} is on a single device')
+
+        sorted_components = list(builder)
+        return sorted_components
 
     def get_langs(self, side):
         if side == 'src':
@@ -452,7 +477,7 @@ def __init__(
         world_context: WorldContext,
         components_to_gpus=None,
         components_to_groups=None,
-        task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
+        task_distribution_strategy: TaskDistributionStrategy = None,
         uses_adapters: bool = False,
         device_context: Optional[DeviceContext] = None,
     ):
@@ -594,14 +619,10 @@ def get_my_grouped_components(self, model):
 
         return my_grouped_components
 
-    def sample_corpus_ids(self, communication_batch_id: int):
-        corpus_id = self.task_distribution_strategy.sample_corpus_ids(
-            1,
-            communication_batch_id,
-        )[0]
-        corpus_ids = [corpus_id for _ in range(self.accum_count)]
-        self.sampled_task_counts.update(corpus_ids)
-        return corpus_ids
+    def sample_corpus_ids(self):
+        active_tasks: Dict[int, List[TaskSpecs]] = self.get_active_tasks()
+        batch_task_sample = self.task_distribution_strategy.sample_corpus_ids(active_tasks)
+        return batch_task_sample
 
     def get_my_encoders(self, layer_stack_index: int):
         my_encoder_ids = [task.encoder_id[layer_stack_index] for task in self.get_my_tasks()]
diff --git a/mammoth/modules/layer_stack_encoder.py b/mammoth/modules/layer_stack_encoder.py
index 621c7c42..7ff401da 100644
--- a/mammoth/modules/layer_stack_encoder.py
+++ b/mammoth/modules/layer_stack_encoder.py
@@ -1,5 +1,5 @@
 from torch import nn
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from mammoth.modules.encoder import EncoderBase
 from mammoth.models.adapters import Adapter, AdaptedTransformerEncoder
@@ -120,9 +120,12 @@ def n_layer_stacks(self):
     def get_submodule(self, layer_stack_index: int, module_id: str):
         return self.encoders[layer_stack_index][module_id]
 
-    def get_adapter(self, module_id: str, adapter_group: str, sub_id: str):
+    def get_adapter(self, adapter_group: str, sub_id: str):
         name = Adapter._name(adapter_group, sub_id)
         layer_stack_index = self._adapter_to_stack[name]
+        # All module_ids in the same slot (should) have the same adapters.
+        # Thus, we can select one arbitrarily.
+        module_id = sorted(self.encoders[layer_stack_index].keys())[0]
         return self.encoders[layer_stack_index][module_id].get_adapter(adapter_group, sub_id)
 
     def add_adapter(