From 98366a3f4e6151fa8457697173b04d58b35531ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?= Date: Mon, 26 Feb 2024 16:45:58 +0200 Subject: [PATCH] Training with determistic task sampling now works on CPU - In order to simplify, CPU and single-GPU now also use multiprocessing, placing the dataloader in a separate process. Several places have been refactored to use the new distributed components: - Init broadcast in train_single.py - Gradient communication in trainer.py (still uses only_ready_reduce_and_rescale_grads though) - Sub-optimizer construction in utils/optimizers.py Two places have been identified as potential future candidates: - model_builder - module splitter for checkpointing The task distribution is now logged in the TQM of the main rank. It no longer has to be done in the dataloader, as each TQM has a global view. --- mammoth/bin/train.py | 136 ++++++++++++----------- mammoth/distributed/communication.py | 6 +- mammoth/distributed/components.py | 4 + mammoth/inputters/dataloader.py | 2 +- mammoth/tests/test_task_queue_manager.py | 14 +-- mammoth/train_single.py | 92 ++++----------- mammoth/trainer.py | 69 +++--------- mammoth/utils/optimizers.py | 57 +++------- mammoth/utils/report_manager.py | 10 +- 9 files changed, 142 insertions(+), 248 deletions(-) diff --git a/mammoth/bin/train.py b/mammoth/bin/train.py index 330b6dde..ff83a01e 100644 --- a/mammoth/bin/train.py +++ b/mammoth/bin/train.py @@ -196,85 +196,91 @@ def train(opts): current_env["MASTER_ADDR"] = opts.master_ip current_env["MASTER_PORT"] = str(opts.master_port) node_rank = opts.node_rank - - queues = [] - semaphores = [] - mp = torch.multiprocessing.get_context('spawn') - logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size)) - # Create a thread to listen for errors in the child processes. - error_queue = mp.SimpleQueue() - error_handler = ErrorHandler(error_queue) - # Train with multiprocessing. - procs = [] - producers = [] - - for local_rank in range(world_context.gpus_per_node): + n_local_ranks = world_context.gpus_per_node + else: + n_local_ranks = 1 + + queues = [] + semaphores = [] + mp = torch.multiprocessing.get_context('spawn') + logger.info("world_size = {}, queue_size = {}".format(opts.world_size, opts.queue_size)) + # Create a thread to listen for errors in the child processes. + error_queue = mp.SimpleQueue() + error_handler = ErrorHandler(error_queue) + # Train with multiprocessing. + procs = [] + producers = [] + + for local_rank in range(n_local_ranks): + if world_context.context == DeviceContextEnum.MULTI_GPU: device_context: DeviceContext = world_context.global_to_local( node_rank=node_rank, local_rank=local_rank, ) - # This task_queue_manager will only yield the items that are active on this gpu - task_queue_manager = global_task_queue_manager.global_to_local( - node_rank=node_rank, - local_rank=local_rank, - opts=opts - ) - # store rank in env (FIXME: is this obsolete?) current_env["RANK"] = str(device_context.global_rank) current_env["LOCAL_RANK"] = str(device_context.local_rank) - - q = mp.Queue(opts.queue_size) - semaphore = mp.Semaphore(opts.queue_size) - queues.append(q) - semaphores.append(semaphore) - procs.append( - mp.Process( - target=consumer, - args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager), - daemon=True, - ) - ) - procs[local_rank].start() - logger.info(" Starting process pid: %d " % procs[local_rank].pid) - error_handler.add_child(procs[local_rank].pid) - - # Get the iterator to generate from - train_iter = DynamicDatasetIter.from_opts( - task_queue_manager=task_queue_manager, - transforms_cls=transforms_cls, - vocabs_dict=vocabs_dict, - opts=opts, - is_train=True, + else: + node_rank = 0 + local_rank = 0 + device_context: DeviceContext = world_context.global_to_local( + node_rank=0, + local_rank=0, ) - producer = mp.Process( - target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True - ) - producers.append(producer) - producers[local_rank].start() - logger.info(" Starting producer process pid: {} ".format(producers[local_rank].pid)) - error_handler.add_child(producers[local_rank].pid) - - for p in procs: - logger.info("DD logger") - p.join() - # Once training is done, we can terminate the producers - for p in producers: - p.terminate() + # This task_queue_manager will only yield the items that are active on this gpu + # for the consumer (trainer process) + task_queue_manager = global_task_queue_manager.global_to_local( + node_rank=node_rank, + local_rank=local_rank, + opts=opts + ) - else: - # SINGLE_GPU or CPU - device_context: DeviceContext = world_context.global_to_local( - node_rank=0, - local_rank=0, + q = mp.Queue(opts.queue_size) + semaphore = mp.Semaphore(opts.queue_size) + queues.append(q) + semaphores.append(semaphore) + procs.append( + mp.Process( + target=consumer, + args=(train_process, opts, device_context, error_queue, q, semaphore, task_queue_manager), + daemon=True, + ) ) + procs[local_rank].start() + logger.info(" Starting process pid: %d " % procs[local_rank].pid) + error_handler.add_child(procs[local_rank].pid) + + # This task_queue_manager will only yield the items that are active on this gpu + # for the producer (dataloader process) task_queue_manager = global_task_queue_manager.global_to_local( - node_rank=0, - local_rank=0, + node_rank=node_rank, + local_rank=local_rank, opts=opts ) - train_process(opts, device_context=device_context, task_queue_manager=task_queue_manager) + # Get the iterator to generate from + train_iter = DynamicDatasetIter.from_opts( + task_queue_manager=task_queue_manager, + transforms_cls=transforms_cls, + vocabs_dict=vocabs_dict, + opts=opts, + is_train=True, + ) + + producer = mp.Process( + target=batch_producer, args=(train_iter, q, semaphore, opts, local_rank), daemon=True + ) + producers.append(producer) + producers[local_rank].start() + logger.info(" Starting producer process pid: {} ".format(producers[local_rank].pid)) + error_handler.add_child(producers[local_rank].pid) + + for p in procs: + logger.info("DD logger") + p.join() + # Once training is done, we can terminate the producers + for p in producers: + p.terminate() def _get_parser(): diff --git a/mammoth/distributed/communication.py b/mammoth/distributed/communication.py index 687da4b9..c984f068 100644 --- a/mammoth/distributed/communication.py +++ b/mammoth/distributed/communication.py @@ -7,6 +7,7 @@ import torch import torch.distributed +from mammoth.distributed.contexts import DeviceContextEnum from mammoth.utils.logging import init_logger, logger from mammoth.utils.misc import set_random_seed @@ -263,7 +264,8 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho f'local_rank {device_context.local_rank}' ) logger.info(f'opts.gpu_ranks {opts.gpu_ranks}') - multi_init(opts, device_context.global_rank) + if device_context.context == DeviceContextEnum.MULTI_GPU: + multi_init(opts, device_context.global_rank) # error_queue not passed (is this intentional?) process_fn( opts, @@ -279,4 +281,4 @@ def consumer(process_fn, opts, device_context, error_queue, batch_queue, semapho # propagate exception to parent process, keeping original traceback import traceback - error_queue.put((opts.gpu_ranks[device_context.node_rank], traceback.format_exc())) + error_queue.put((device_context.node_rank, traceback.format_exc())) diff --git a/mammoth/distributed/components.py b/mammoth/distributed/components.py index 0ca7bb08..53cab89a 100644 --- a/mammoth/distributed/components.py +++ b/mammoth/distributed/components.py @@ -68,6 +68,10 @@ def named_parameters(self, model: NMTModel): def min_rank(self) -> int: return min(self.global_ranks) + def needs_communication(self) -> bool: + # if the component needs communication, a group must be set + return self.group is not None + @dataclass class DistributedXCoder(DistributedComponent, ABC): diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index 542b9396..8fcb87fb 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -341,7 +341,7 @@ def __iter__(self): batch_task_sample = self.task_queue_manager.sample_corpus_ids() my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank] ordered_iter, metadata = self.dataset_iterators[my_task.corpus_id] - for _ in self.task_queue_manager.accum_count: + for _ in range(self.task_queue_manager.accum_count): batch = next(ordered_iter) if batch_task_sample.training_step == 0: # De-numericalize a few sentences for debugging diff --git a/mammoth/tests/test_task_queue_manager.py b/mammoth/tests/test_task_queue_manager.py index cdf8fd9d..923d0758 100644 --- a/mammoth/tests/test_task_queue_manager.py +++ b/mammoth/tests/test_task_queue_manager.py @@ -144,15 +144,15 @@ def __call__(self, sorted_global_ranks): use_attention_bridge=False, new_group_func=MockGroup() ) assert all_components == [ + DistributedDecoder( + global_ranks={0, 2}, group='Group 0 with GPU ranks [0, 2]', layer_stack_index=0, xcoder_id='y' + ), + DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'), DistributedEncoder( - global_ranks={0, 1}, group='Group 0 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x' + global_ranks={0, 1}, group='Group 1 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x' ), DistributedEncoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='xx'), DistributedEncoder(global_ranks={2}, group=None, layer_stack_index=0, xcoder_id='xxx'), - DistributedDecoder( - global_ranks={0, 2}, group='Group 1 with GPU ranks [0, 2]', layer_stack_index=0, xcoder_id='y' - ), - DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'), DistributedGenerator(global_ranks={0, 2}, group='Group 2 with GPU ranks [0, 2]', lang='b'), DistributedGenerator(global_ranks={1}, group=None, lang='d'), DistributedEmbedding(global_ranks={0, 1}, group='Group 3 with GPU ranks [0, 1]', side=Side.encoder, lang='a'), @@ -183,11 +183,11 @@ def __call__(self, sorted_global_ranks): if component not in all_components: raise Exception(f'my component {component} not in all_components {all_components}') assert my_components == [ + DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'), DistributedEncoder( - global_ranks={0, 1}, group='Group 0 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x' + global_ranks={0, 1}, group='Group 1 with GPU ranks [0, 1]', layer_stack_index=0, xcoder_id='x' ), DistributedEncoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='xx'), - DistributedDecoder(global_ranks={1}, group=None, layer_stack_index=0, xcoder_id='yy'), DistributedGenerator(global_ranks={1}, group=None, lang='d'), DistributedEmbedding(global_ranks={0, 1}, group='Group 3 with GPU ranks [0, 1]', side=Side.encoder, lang='a'), DistributedEmbedding(global_ranks={1}, group=None, side=Side.encoder, lang='c'), diff --git a/mammoth/train_single.py b/mammoth/train_single.py index da972d3d..51e914c9 100644 --- a/mammoth/train_single.py +++ b/mammoth/train_single.py @@ -57,50 +57,17 @@ def _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager): def init_distributed(model, task_queue_manager): - my_component_groups = task_queue_manager.get_my_distributed_groups() - for (layer_stack_index, encoder_id), (min_rank, group) in my_component_groups['encoder'].items(): - weights = [ - p.data for name, p - in model.encoder.get_submodule(layer_stack_index, encoder_id).named_parameters() - if 'embeddings' not in name and 'adapter' not in name - ] - broadcast_tensors(weights, src=min_rank, group=group) - - for (layer_stack_index, decoder_id), (min_rank, group) in my_component_groups['decoder'].items(): - weights = [ - p.data for name, p - in model.decoder.get_submodule(layer_stack_index, decoder_id).named_parameters() - if 'embeddings' not in name and 'adapter' not in name - ] - broadcast_tensors(weights, src=min_rank, group=group) - - for (src_lang,), (min_rank, group) in my_component_groups['src_emb'].items(): - embs = model.encoder.embeddings[f'embeddings_{src_lang}'] - weights = [p.data for p in embs.parameters()] - broadcast_tensors(weights, src=min_rank, group=group) - - for (tgt_lang,), (min_rank, group) in my_component_groups['tgt_emb'].items(): - embs = model.decoder.embeddings[f'embeddings_{tgt_lang}'] - weights = [p.data for p in embs.parameters()] - broadcast_tensors(weights, src=min_rank, group=group) - - weights = [p.data for p in model.generator[f'generator_{tgt_lang}'].parameters()] - broadcast_tensors(weights, src=min_rank, group=group) - - for adapter_id, (min_rank, group) in my_component_groups['encoder_adapters'].items(): - layer_stack_index, encoder_id, adapter_group, sub_id = adapter_id - adapter = model.encoder.get_submodule(layer_stack_index, encoder_id).get_adapter(adapter_group, sub_id) - weights = [p.data for name, p in adapter.named_parameters()] - broadcast_tensors(weights, src=min_rank, group=group) - - for adapter_id, (min_rank, group) in my_component_groups['decoder_adapters'].items(): - layer_stack_index, decoder_id, adapter_group, sub_id = adapter_id - adapter = model.decoder.get_submodule(layer_stack_index, decoder_id).get_adapter(adapter_group, sub_id) - weights = [p.data for name, p in adapter.named_parameters()] - broadcast_tensors(weights, src=min_rank, group=group) - - weights = [p.data for p in model.attention_bridge.parameters()] - broadcast_tensors(weights, src=0) + # All components on device, in consistent order across devices + my_components = task_queue_manager.get_my_distributed_components() + # Omit components not found elsewhere, as these don't need to be communicated + components_to_communicate = [ + component for component in my_components + if component.needs_communication() + ] + + for component in components_to_communicate: + weights = [p.data for name, p in component.named_parameters(model)] + broadcast_tensors(weights, src=component.min_rank, group=component.group) logger.debug('After init_distributed') for name, p in model.named_parameters(): @@ -134,6 +101,8 @@ def main( checkpoint = None model_opts = _get_model_opts(opts, checkpoint=checkpoint) + task_queue_manager.create_all_distributed_components(use_attention_bridge=model_opts.bridge) + # Build model. model, generators_md = build_model(model_opts, opts, vocabs_dict, task_queue_manager, checkpoint) @@ -141,9 +110,6 @@ def main( logger.info("{} - Init model".format(device_context.id)) if device_context.is_distributed(): init_distributed(model, task_queue_manager) - else: - # Initialize some data structures - _ = task_queue_manager.get_my_distributed_groups() enc, dec = model.count_parameters(log=logger.debug) logger.info("{} - total encoder parameters: {}".format(device_context.id, enc)) logger.info("{} - total decoder parameters: {}".format(device_context.id, dec)) @@ -173,30 +139,18 @@ def main( ) logger.info("{} - Trainer built".format(device_context.id)) - if batch_queue is None: - train_iter = DynamicDatasetIter.from_opts( - task_queue_manager=task_queue_manager, - transforms_cls=transforms_cls, - vocabs_dict=vocabs_dict, - opts=opts, - is_train=True, - ) - # TODO: check that IterOnDevice is unnecessary here; corpora should be already on device - # if device_context.is_gpu(): - # train_iter = IterOnDevice(_train_iter, device_context.local_rank) - # else: - # train_iter = IterOnDevice(_train_iter, -1) - else: - assert semaphore is not None, "Using batch_queue requires semaphore as well" + # It is no longer possible to train without multiprocessing + assert batch_queue is not None + assert semaphore is not None - def _train_iter(): - while True: - batch, metadata, communication_batch_id = batch_queue.get() - semaphore.release() - # TODO: confirm that batch-providing corpus has already been to'd to the correct place - yield batch, metadata, communication_batch_id + def _train_iter(): + while True: + batch, metadata, communication_batch_id = batch_queue.get() + semaphore.release() + # TODO: confirm that batch-providing corpus has already been to'd to the correct place + yield batch, metadata, communication_batch_id - train_iter = _train_iter() + train_iter = _train_iter() # train_iter = iter_on_device(train_iter, device_context) logger.info("Device {} - Valid iter".format(device_context.id)) valid_iter = _build_valid_iter(opts, vocabs_dict, transforms_cls, task_queue_manager) diff --git a/mammoth/trainer.py b/mammoth/trainer.py index 5b420764..80390da7 100644 --- a/mammoth/trainer.py +++ b/mammoth/trainer.py @@ -190,13 +190,6 @@ def __init__( self.dropout_steps = dropout_steps self.task_queue_manager = task_queue_manager - my_component_groups = self.task_queue_manager.get_my_distributed_groups() - self.my_encoder_groups = my_component_groups['encoder'] - self.my_decoder_groups = my_component_groups['decoder'] - self.my_src_emb_groups = my_component_groups['src_emb'] - self.my_tgt_emb_groups = my_component_groups['tgt_emb'] - self.my_encoder_adapter_groups = my_component_groups['encoder_adapters'] - self.my_decoder_adapter_groups = my_component_groups['decoder_adapters'] for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -277,6 +270,7 @@ def train( batches_with_meta = islice(train_iter, self.accum_count) batch_task_sample = self.task_queue_manager.sample_corpus_ids() + logger.info(f'batch_task_sample has {batch_task_sample.training_step}') my_task = batch_task_sample.tasks[self.task_queue_manager.global_rank] self._gradient_accumulation( @@ -286,54 +280,17 @@ def train( my_task, ) - # Note that all group ids are tuples, some with length 1 - for (layer_stack_index, encoder_id), (_, group) in self.my_encoder_groups.items(): - params = [ - (name, p) for (name, p) - in self.model.encoder.get_submodule(layer_stack_index, encoder_id).named_parameters() - if 'embeddings' not in name and 'adapter' not in name - ] - mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) - - for (layer_stack_index, decoder_id), (_, group) in self.my_decoder_groups.items(): - params = [ - (name, p) for (name, p) - in self.model.decoder.get_submodule(layer_stack_index, decoder_id).named_parameters() - if 'embeddings' not in name and 'adapter' not in name - ] - mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=group) - - for (src_lang,), (_, group) in self.my_src_emb_groups.items(): - embs = self.model.encoder.embeddings[f'embeddings_{src_lang}'] - mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) - - for (tgt_lang,), (_, group) in self.my_tgt_emb_groups.items(): - embs = self.model.decoder.embeddings[f'embeddings_{tgt_lang}'] - mammoth.distributed.only_ready_reduce_and_rescale_grads(embs.named_parameters(), group=group) - - mammoth.distributed.only_ready_reduce_and_rescale_grads( - self.model.generator[f'generator_{tgt_lang}'].named_parameters(), group=group - ) - - for adapter_id, (_, group) in self.my_encoder_adapter_groups.items(): - layer_stack_index, encoder_id, adapter_group, sub_id = adapter_id - adapter = self.model.encoder.get_submodule(layer_stack_index, encoder_id).get_adapter( - adapter_group, sub_id - ) - mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) + # All components on device, in consistent order across devices + my_components = self.task_queue_manager.get_my_distributed_components() + # Omit components not found elsewhere, as these don't need to be communicated + components_to_communicate = [ + component for component in my_components + if component.needs_communication() + ] - for adapter_id, (_, group) in self.my_decoder_adapter_groups.items(): - layer_stack_index, decoder_id, adapter_group, sub_id = adapter_id - adapter = self.model.decoder.get_submodule(layer_stack_index, decoder_id).get_adapter( - adapter_group, sub_id - ) - mammoth.distributed.only_ready_reduce_and_rescale_grads(adapter.named_parameters(), group=group) - - # a group is not specified: reduce across all devices - if device_context.is_distributed(): - mammoth.distributed.only_ready_reduce_and_rescale_grads( - self.model.attention_bridge.named_parameters() - ) + for component in components_to_communicate: + params = component.named_parameters(self.model) + mammoth.distributed.only_ready_reduce_and_rescale_grads(params, group=component.group) self._maybe_update_stats_from_parameters(report_stats, self.model.named_parameters()) @@ -473,8 +430,8 @@ def _gradient_accumulation( for k, (batch, metadata, comm_batch) in enumerate(batches_with_meta): if metadata != expected_metadata: raise Exception( - 'Mismatch in task sampling. ' - f'Received {metadata}, expected {expected_metadata}' + f'Mismatch in task sampling for batch {comm_batch}.\n ' + f'Received {metadata},\n expected {expected_metadata}' ) seen_comm_batches.add(comm_batch) if self.norm_method == "tokens": diff --git a/mammoth/utils/optimizers.py b/mammoth/utils/optimizers.py index f72a658d..b100ad41 100644 --- a/mammoth/utils/optimizers.py +++ b/mammoth/utils/optimizers.py @@ -13,49 +13,20 @@ def attention_bridge_optimizer(model, task_queue_manager, base_optimizer): suboptimizers = {} - my_grouped_components = task_queue_manager.get_my_grouped_components(model) - for component_type in my_grouped_components: - for component_id, component in my_grouped_components[component_type].items(): - if isinstance(component_id, str): - name = component_type + '_' + component_id - else: - name = component_type + '_' + '_'.join([str(x) for x in component_id]) - params = [] - for param_name, param in component.named_parameters(): - if not param.requires_grad: - continue - if 'adapter' in param_name and 'adapter' not in component_type: - # omit adapters from base component optimizers - continue - if 'embedding' in param_name: - print(f'adding {param_name} to suboptimizer {name}') - params.append(param) - if name in suboptimizers: - raise Exception(f'Trying to create second optimizer for "{name}"') - if len(params) != 0: - optimizer = base_optimizer(params) - suboptimizers[name] = optimizer - - for generator_id in task_queue_manager.get_my_generators(): - generator = model.generator[f'generator_{generator_id}'] - params = [] - for name, param in generator.named_parameters(): - if not param.requires_grad: - continue - params.append(param) - optimizer = base_optimizer(params) - suboptimizers[f'generator_{generator_id}'] = optimizer - - attParam = [] - for name, param in model.attention_bridge.named_parameters(): - if not param.requires_grad: - continue - attParam.append(param) - - # skip AB optimizer if AB is not in use - if len(attParam): - optimizer = base_optimizer(attParam) - suboptimizers["attention_bridge"] = optimizer + # All components on device, in consistent order across devices + my_components = task_queue_manager.get_my_distributed_components() + # Also keeping components that are on a single device + for component in my_components: + name = component.get_name() + params = [ + param for param_name, param in component.named_parameters(model) + if param.requires_grad + ] + if name in suboptimizers: + raise Exception(f'Trying to create second optimizer for "{name}"') + if len(params) != 0: + optimizer = base_optimizer(params) + suboptimizers[name] = optimizer optimizer = MultipleOptimizer(suboptimizers, None) return optimizer diff --git a/mammoth/utils/report_manager.py b/mammoth/utils/report_manager.py index 3ad2d631..cc3702f1 100644 --- a/mammoth/utils/report_manager.py +++ b/mammoth/utils/report_manager.py @@ -59,7 +59,7 @@ def report_training( patience, report_stats, multigpu=False, - sampled_task_count=None, + sampled_task_counts=None, ): """ This is the user-defined batch-level traing progress @@ -89,7 +89,7 @@ def report_training( learning_rate, patience, report_stats, - sampled_task_count=sampled_task_count + sampled_task_counts=sampled_task_counts ) return mammoth.utils.Statistics() else: @@ -145,7 +145,7 @@ def maybe_log_tensorboard(self, stats, prefix, learning_rate, patience, step): if self.tensorboard_writer is not None: stats.log_tensorboard(prefix, self.tensorboard_writer, learning_rate, patience, step) - def _report_training(self, step, num_steps, learning_rate, patience, report_stats, sampled_task_count): + def _report_training(self, step, num_steps, learning_rate, patience, report_stats, sampled_task_counts): """ See base class method `ReportMgrBase.report_training`. """ @@ -154,9 +154,9 @@ def _report_training(self, step, num_steps, learning_rate, patience, report_stat self.maybe_log_tensorboard(report_stats, "progress", learning_rate, patience, step) report_stats = mammoth.utils.Statistics() - total = sum(sampled_task_count.values()) + total = sum(sampled_task_counts.values()) logger.info(f'Task sampling distribution: (total {total})') - for task, count in sampled_task_count.most_common(): + for task, count in sampled_task_counts.most_common(): logger.info(f'Task: {task}\tcount: {count}\t{100 * count / total} %') return report_stats