Training with deterministic task sampling now works on CPU
- To simplify the code, CPU and single-GPU training now also use multiprocessing,
  placing the dataloader in a separate process (see the sketch below).
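A rough, self-contained sketch of this producer/consumer pattern, now shared by all device setups. It is illustrative only: the toy_* functions below emit dummy tensors, while the real batch_producer and consumer live in mammoth/bin/train.py and mammoth/distributed/communication.py; only the Queue/Semaphore wiring mirrors the diff further down.

# Illustrative only: a dataloader (producer) process feeds a trainer (consumer)
# process through a bounded queue; a semaphore provides back-pressure.
import torch


def toy_batch_producer(queue, semaphore, n_batches):
    # Stand-in for the real dataloader: emit dummy tensors.
    for step in range(n_batches):
        semaphore.acquire()              # wait until the consumer has room
        queue.put((step, torch.randn(2, 3)))
    queue.put(None)                      # sentinel: no more batches


def toy_consumer(queue, semaphore):
    while True:
        item = queue.get()
        if item is None:
            break
        step, batch = item
        semaphore.release()              # free a slot for the producer
        print(f"step {step}: batch mean {batch.mean().item():.3f}")


if __name__ == '__main__':
    mp = torch.multiprocessing.get_context('spawn')
    queue_size = 4
    q = mp.Queue(queue_size)
    semaphore = mp.Semaphore(queue_size)
    producer = mp.Process(target=toy_batch_producer, args=(q, semaphore, 8), daemon=True)
    trainer = mp.Process(target=toy_consumer, args=(q, semaphore), daemon=True)
    producer.start()
    trainer.start()
    trainer.join()
    producer.join()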

Several places have been refactored to use the new distributed
components:
- Init broadcast in train_single.py
- Gradient communication in trainer.py
  (still uses only_ready_reduce_and_rescale_grads though)
- Sub-optimizer construction in utils/optimizers.py

Two places have been identified as candidates for the same refactoring in the future:
- model_builder
- module splitter for checkpointing

The task distribution is now logged by the TaskQueueManager (TQM) of the main
rank. It no longer needs to be done in the dataloader, as each TQM has a global
view of the sampled tasks.
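The sampling logic itself is not part of this diff. The following hypothetical sketch only illustrates why a single rank can log the distribution: if every TQM seeds its sampler with the same seed and training step, all ranks independently draw the same global task assignment. BatchTaskSample and sample_tasks are placeholder names, not mammoth APIs (the real entry point is TaskQueueManager.sample_corpus_ids, used in dataloader.py below).

# Hypothetical illustration, not the mammoth implementation: deterministic,
# globally consistent task sampling keyed on (seed, training_step).
import random
from dataclasses import dataclass


@dataclass
class BatchTaskSample:
    training_step: int
    tasks: dict                      # global_rank -> corpus/task id


def sample_tasks(seed: int, training_step: int, tasks_per_rank: dict) -> BatchTaskSample:
    """Every rank calling this with the same arguments gets the same result."""
    rng = random.Random(f"{seed}:{training_step}")
    sampled = {
        rank: rng.choice(candidates)
        for rank, candidates in sorted(tasks_per_rank.items())
    }
    return BatchTaskSample(training_step=training_step, tasks=sampled)


if __name__ == '__main__':
    tasks_per_rank = {0: ['en-de', 'en-fi'], 1: ['en-fi', 'en-sv']}
    # Both "ranks" reproduce the full assignment independently,
    # so only the main rank needs to log it.
    on_rank0 = sample_tasks(seed=42, training_step=7, tasks_per_rank=tasks_per_rank)
    on_rank1 = sample_tasks(seed=42, training_step=7, tasks_per_rank=tasks_per_rank)
    assert on_rank0 == on_rank1
    print(on_rank0.tasks)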
Waino committed Feb 26, 2024
1 parent 3ec9f22 commit 98366a3
Showing 9 changed files with 142 additions and 248 deletions.
136 changes: 71 additions & 65 deletions mammoth/bin/train.py
@@ -196,85 +196,91 @@ def train(opts):
current_env["MASTER_ADDR"] = opts.master_ip
current_env["MASTER_PORT"] = str(opts.master_port)
node_rank = opts.node_rank

queues = []
semaphores = []
mp = torch.multiprocessing.get_context('spawn')
logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size))
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
producers = []

for local_rank in range(world_context.gpus_per_node):
n_local_ranks = world_context.gpus_per_node
else:
n_local_ranks = 1

queues = []
semaphores = []
mp = torch.multiprocessing.get_context('spawn')
logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size))
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
producers = []

for local_rank in range(n_local_ranks):
if world_context.context == DeviceContextEnum.MULTI_GPU:
device_context: DeviceContext = world_context.global_to_local(
node_rank=node_rank,
local_rank=local_rank,
)
# This task_queue_manager will only yield the items that are active on this gpu
task_queue_manager = global_task_queue_manager.global_to_local(
node_rank=node_rank,
local_rank=local_rank,
opts=opts
)

# store rank in env (FIXME: is this obsolete?)
current_env["RANK"] = str(device_context.global_rank)
current_env["LOCAL_RANK"] = str(device_context.local_rank)

q = mp.Queue(opts.queue_size)
semaphore = mp.Semaphore(opts.queue_size)
queues.append(q)
semaphores.append(semaphore)
procs.append(
mp.Process(
target=consumer,
args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager),
daemon=True,
)
)
procs[local_rank].start()
logger.info(" Starting process pid: %d " % procs[local_rank].pid)
error_handler.add_child(procs[local_rank].pid)

# Get the iterator to generate from
train_iter = DynamicDatasetIter.from_opts(
task_queue_manager=task_queue_manager,
transforms_cls=transforms_cls,
vocabs_dict=vocabs_dict,
opts=opts,
is_train=True,
else:
node_rank = 0
local_rank = 0
device_context: DeviceContext = world_context.global_to_local(
node_rank=0,
local_rank=0,
)

producer = mp.Process(
target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True
)
producers.append(producer)
producers[local_rank].start()
logger.info(" Starting producer process pid: {} ".format(producers[local_rank].pid))
error_handler.add_child(producers[local_rank].pid)

for p in procs:
logger.info("DD logger")
p.join()
# Once training is done, we can terminate the producers
for p in producers:
p.terminate()
# This task_queue_manager will only yield the items that are active on this gpu
# for the consumer (trainer process)
task_queue_manager = global_task_queue_manager.global_to_local(
node_rank=node_rank,
local_rank=local_rank,
opts=opts
)

else:
# SINGLE_GPU or CPU
device_context: DeviceContext = world_context.global_to_local(
node_rank=0,
local_rank=0,
q = mp.Queue(opts.queue_size)
semaphore = mp.Semaphore(opts.queue_size)
queues.append(q)
semaphores.append(semaphore)
procs.append(
mp.Process(
target=consumer,
args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager),
daemon=True,
)
)
procs[local_rank].start()
logger.info(" Starting process pid: %d " % procs[local_rank].pid)
error_handler.add_child(procs[local_rank].pid)

# This task_queue_manager will only yield the items that are active on this gpu
# for the producer (dataloader process)
task_queue_manager = global_task_queue_manager.global_to_local(
node_rank=0,
local_rank=0,
node_rank=node_rank,
local_rank=local_rank,
opts=opts
)
train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager)
# Get the iterator to generate from
train_iter = DynamicDatasetIter.from_opts(
task_queue_manager=task_queue_manager,
transforms_cls=transforms_cls,
vocabs_dict=vocabs_dict,
opts=opts,
is_train=True,
)

producer = mp.Process(
target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True
)
producers.append(producer)
producers[local_rank].start()
logger.info(" Starting producer process pid: {} ".format(producers[local_rank].pid))
error_handler.add_child(producers[local_rank].pid)

for p in procs:
logger.info("DD logger")
p.join()
# Once training is done, we can terminate the producers
for p in producers:
p.terminate()


def _get_parser():
6 changes: 4 additions & 2 deletions mammoth/distributed/communication.py
@@ -7,6 +7,7 @@
import torch
import torch.distributed

from mammoth.distributed.contexts import DeviceContextEnum
from mammoth.utils.logging import init_logger, logger
from mammoth.utils.misc import set_random_seed

@@ -263,7 +264,8 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho
f'local_rank {device_context.local_rank}'
)
logger.info(f'opts.gpu_ranks {opts.gpu_ranks}')
multi_init(opts, device_context.global_rank)
if device_context.context == DeviceContextEnum.MULTI_GPU:
multi_init(opts, device_context.global_rank)
# error_queue not passed (is this intentional?)
process_fn(
opts,
@@ -279,4 +281,4 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho
# propagate exception to parent process, keeping original traceback
import traceback

error_queue.put((opts.gpu_ranks[device_context.node_rank], traceback.format_exc()))
error_queue.put((device_context.node_rank, traceback.format_exc()))
4 changes: 4 additions & 0 deletions mammoth/distributed/components.py
@@ -68,6 +68,10 @@ def named_parameters(self, model: NMTModel):
def min_rank(self) -> int:
return min(self.global_ranks)

def needs_communication(self) -> bool:
# if the component needs communication, a group must be set
return self.group is not None


@dataclass
class DistributedXCoder(DistributedComponent, ABC):
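A toy, self-contained sketch of the pattern this new predicate enables: components whose parameters also live on other devices carry a process group, and only those take part in the init broadcast. The broadcast is mocked here with a print; in mammoth the corresponding loop in train_single.py (shown further down) uses the real broadcast_tensors and the components returned by get_my_distributed_components.

# Toy sketch, assuming a mocked broadcast: only components shared with other
# ranks carry a group, and only those are synchronised.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyComponent:
    name: str
    global_ranks: frozenset
    group: Optional[object] = None   # stand-in for a torch.distributed group

    def needs_communication(self) -> bool:
        # A group is only created for components present on more than one rank.
        return self.group is not None

    @property
    def min_rank(self) -> int:
        return min(self.global_ranks)


def init_broadcast(components):
    """Mirror of the refactored init_distributed loop (broadcast is mocked)."""
    for component in components:
        if not component.needs_communication():
            continue   # nothing to synchronise: the component exists only here
        print(f"broadcast {component.name} from rank {component.min_rank} in group {component.group}")


if __name__ == '__main__':
    my_components = [
        ToyComponent('encoder_0_x', frozenset({0, 1}), group='grp(0,1)'),
        ToyComponent('decoder_0_yy', frozenset({1})),   # exists only on this device
    ]
    init_broadcast(my_components)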
2 changes: 1 addition & 1 deletion mammoth/inputters/dataloader.py
@@ -341,7 +341,7 @@ def __iter__(self):
batch_task_sample = self.task_queue_manager.sample_corpus_ids()
my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank]
ordered_iter, metadata = self.dataset_iterators[my_task.corpus_id]
for _ in self.task_queue_manager.accum_count:
for _ in range(self.task_queue_manager.accum_count):
batch = next(ordered_iter)
if batch_task_sample.training_step == 0:
# De-numericalize a few sentences for debugging
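The one-line change above fixes a TypeError: accum_count is an int, so it cannot be iterated directly, while range(accum_count) pulls the intended number of batches per communication step. A minimal sketch of that loop, with a plain iterator standing in for the ordered dataset iterator:

# Minimal sketch of pulling accum_count batches per communication step.
from itertools import count


def accumulated_batches(ordered_iter, accum_count):
    """Yield (communication_batch_id, list of accum_count consecutive batches)."""
    for communication_batch_id in count():
        chunk = [next(ordered_iter) for _ in range(accum_count)]   # the fixed form
        yield communication_batch_id, chunk


if __name__ == '__main__':
    fake_batches = iter(range(100))   # stand-in for the ordered dataset iterator
    stream = accumulated_batches(fake_batches, accum_count=4)
    print(next(stream))   # (0, [0, 1, 2, 3])
    print(next(stream))   # (1, [4, 5, 6, 7])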
14 changes: 7 additions & 7 deletions mammoth/tests/test_task_queue_manager.py
@@ -144,15 +144,15 @@ def __call__(self, sorted_global_ranks):
use_attention_bridge=False, new_group_func=MockGroup()
)
assert all_components == [
DistributedDecoder(
global_ranks={0, 2}, group='Group 0 with GPU ranks [0, 2]', layer_stack_index=0, xcoder_id='y'
),
DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'),
DistributedEncoder(
global_ranks={0, 1}, group='Group 0 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x'
global_ranks={0, 1}, group='Group 1 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x'
),
DistributedEncoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='xx'),
DistributedEncoder(global_ranks={2}, group=None, layer_stack_index=0, xcoder_id='xxx'),
DistributedDecoder(
global_ranks={0, 2}, group='Group 1 with GPU ranks [0, 2]', layer_stack_index=0, xcoder_id='y'
),
DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'),
DistributedGenerator(global_ranks={0, 2}, group='Group 2 with GPU ranks [0, 2]', lang='b'),
DistributedGenerator(global_ranks={1}, group=None, lang='d'),
DistributedEmbedding(global_ranks={0, 1}, group='Group 3 with GPU ranks [0, 1]', side=Side.encoder, lang='a'),
@@ -183,11 +183,11 @@ def __call__(self, sorted_global_ranks):
if component not in all_components:
raise Exception(f'my component {component} not in all_components {all_components}')
assert my_components == [
DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'),
DistributedEncoder(
global_ranks={0, 1}, group='Group 0 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x'
global_ranks={0, 1}, group='Group 1 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x'
),
DistributedEncoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='xx'),
DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'),
DistributedGenerator(global_ranks={1}, group=None, lang='d'),
DistributedEmbedding(global_ranks={0, 1}, group='Group 3 with GPU ranks [0, 1]', side=Side.encoder, lang='a'),
DistributedEmbedding(global_ranks={1}, group=None, side=Side.encoder, lang='c'),
92 changes: 23 additions & 69 deletions mammoth/train_single.py
@@ -57,50 +57,17 @@ def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager):


def init_distributed(model, task_queue_manager):
my_component_groups = task_queue_manager.get_my_distributed_groups()
for (layer_stack_index, encoder_id), (min_rank, group) in my_component_groups['encoder'].items():
weights = [
p.data for name, p
in model.encoder.get_submodule(layer_stack_index, encoder_id).named_parameters()
if 'embeddings' not in name and 'adapter' not in name
]
broadcast_tensors(weights, src=min_rank, group=group)

for (layer_stack_index, decoder_id), (min_rank, group) in my_component_groups['decoder'].items():
weights = [
p.data for name, p
in model.decoder.get_submodule(layer_stack_index, decoder_id).named_parameters()
if 'embeddings' not in name and 'adapter' not in name
]
broadcast_tensors(weights, src=min_rank, group=group)

for (src_lang,), (min_rank, group) in my_component_groups['src_emb'].items():
embs = model.encoder.embeddings[f'embeddings_{src_lang}']
weights = [p.data for p in embs.parameters()]
broadcast_tensors(weights, src=min_rank, group=group)

for (tgt_lang,), (min_rank, group) in my_component_groups['tgt_emb'].items():
embs = model.decoder.embeddings[f'embeddings_{tgt_lang}']
weights = [p.data for p in embs.parameters()]
broadcast_tensors(weights, src=min_rank, group=group)

weights = [p.data for p in model.generator[f'generator_{tgt_lang}'].parameters()]
broadcast_tensors(weights, src=min_rank, group=group)

for adapter_id, (min_rank, group) in my_component_groups['encoder_adapters'].items():
layer_stack_index, encoder_id, adapter_group, sub_id = adapter_id
adapter = model.encoder.get_submodule(layer_stack_index, encoder_id).get_adapter(adapter_group, sub_id)
weights = [p.data for name, p in adapter.named_parameters()]
broadcast_tensors(weights, src=min_rank, group=group)

for adapter_id, (min_rank, group) in my_component_groups['decoder_adapters'].items():
layer_stack_index, decoder_id, adapter_group, sub_id = adapter_id
adapter = model.decoder.get_submodule(layer_stack_index, decoder_id).get_adapter(adapter_group, sub_id)
weights = [p.data for name, p in adapter.named_parameters()]
broadcast_tensors(weights, src=min_rank, group=group)

weights = [p.data for p in model.attention_bridge.parameters()]
broadcast_tensors(weights, src=0)
# All components on device, in consistent order across devices
my_components = task_queue_manager.get_my_distributed_components()
# Omit components not found elsewhere, as these don't need to be communicated
components_to_communicate = [
component for component in my_components
if component.needs_communication()
]

for component in components_to_communicate:
weights = [p.data for name, p in component.named_parameters(model)]
broadcast_tensors(weights, src=component.min_rank, group=component.group)

logger.debug('After init_distributed')
for name, p in model.named_parameters():
@@ -134,16 +101,15 @@ def main(
checkpoint = None
model_opts = _get_model_opts(opts, checkpoint=checkpoint)

task_queue_manager.create_all_distributed_components(use_attention_bridge=model_opts.bridge)

# Build model.

model, generators_md = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint)

logger.info("{} - Init model".format(device_context.id))
if device_context.is_distributed():
init_distributed(model, task_queue_manager)
else:
# Initialize some data structures
_ = task_queue_manager.get_my_distributed_groups()
enc, dec = model.count_parameters(log=logger.debug)
logger.info("{} - total encoder parameters: {}".format(device_context.id, enc))
logger.info("{} - total decoder parameters: {}".format(device_context.id, dec))
@@ -173,30 +139,18 @@ def main(
)
logger.info("{} - Trainer built".format(device_context.id))

if batch_queue is None:
train_iter = DynamicDatasetIter.from_opts(
task_queue_manager=task_queue_manager,
transforms_cls=transforms_cls,
vocabs_dict=vocabs_dict,
opts=opts,
is_train=True,
)
# TODO: check that IterOnDevice is unnecessary here; corpora should be already on device
# if device_context.is_gpu():
# train_iter = IterOnDevice(_train_iter, device_context.local_rank)
# else:
# train_iter = IterOnDevice(_train_iter, -1)
else:
assert semaphore is not None, "Using batch_queue requires semaphore as well"
# It is no longer possible to train without multiprocessing
assert batch_queue is not None
assert semaphore is not None

def _train_iter():
while True:
batch, metadata, communication_batch_id = batch_queue.get()
semaphore.release()
# TODO: confirm that batch-providing corpus has already been to'd to the correct place
yield batch, metadata, communication_batch_id
def _train_iter():
while True:
batch, metadata, communication_batch_id = batch_queue.get()
semaphore.release()
# TODO: confirm that batch-providing corpus has already been to'd to the correct place
yield batch, metadata, communication_batch_id

train_iter = _train_iter()
train_iter = _train_iter()
# train_iter = iter_on_device(train_iter, device_context)
logger.info("Device {} - Valid iter".format(device_context.id))
valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager)