From 7ec88bed64b0a1575045d5eafc3be9234cdea7d8 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 29 Oct 2024 11:52:15 -0700
Subject: [PATCH] ignore host-device syncs

---
 pyproject.toml                            |  2 +-
 .../train/callbacks/evaluator_callback.py | 23 +++++++++++++------
 src/olmo_core/utils.py                    | 18 +++++++++++++++
 3 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 244dba69..89f40d2d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "omegaconf",
     "safetensors",
     "importlib_resources",
-    "ai2-olmo-eval==0.1.0",
+    "ai2-olmo-eval==0.2.0",
 ]
 
 [project.urls]
diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py
index cfe484d4..a46d3047 100644
--- a/src/olmo_core/train/callbacks/evaluator_callback.py
+++ b/src/olmo_core/train/callbacks/evaluator_callback.py
@@ -11,7 +11,12 @@
 from olmo_core.eval import Evaluator
 from olmo_core.eval.lm_evaluator import LMEvaluator
 from olmo_core.exceptions import OLMoConfigurationError
-from olmo_core.utils import format_float, get_default_device, move_to_device
+from olmo_core.utils import (
+    cuda_sync_debug_mode,
+    format_float,
+    get_default_device,
+    move_to_device,
+)
 
 from ..common import Duration
 from .callback import Callback, CallbackConfig
@@ -71,7 +76,10 @@ def post_step(self):
                 logits, ce_loss, _ = self.trainer.eval_batch(
                     batch, loss_reduction="none", compute_z_loss=False
                 )
-                evaluator.update_metrics(batch, ce_loss, logits)
+
+                # NOTE: might have host-device syncs here but that's okay.
+                with cuda_sync_debug_mode(0):
+                    evaluator.update_metrics(batch, ce_loss, logits)
 
                 if eval_step % self.trainer.cancel_check_interval == 0:
                     self.trainer.check_if_canceled()
@@ -89,10 +97,11 @@ def post_step(self):
             # NOTE: going to have a host-device sync here but that's okay. It's only once
             # per evaluator.
             metrics = []
-            for name, value in evaluator.compute_metrics().items():
-                value = value.item()
-                metrics.append(f" {name}={format_float(value)}")
-                self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value)
+            with cuda_sync_debug_mode(0):
+                for name, value in evaluator.compute_metrics().items():
+                    value = value.item()
+                    metrics.append(f" {name}={format_float(value)}")
+                    self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value)
             log.info("Eval metrics:\n" + "\n".join(metrics))
 
         # Restore model to train mode.
@@ -180,7 +189,7 @@ def __init__(
         rank_batch_size_instances = max(0, rank_batch_size // self.task.max_sequence_length)
         log.info(
             f"Using per-rank batch size of {rank_batch_size_instances} instances "
-            f"for downstream eval task '{task}'"
+            f"for downstream eval task '{task}' with max sequence length {self.task.max_sequence_length:,d} tokens"
         )
 
         data_loader = DataLoader(
diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py
index cc049de9..a4389629 100644
--- a/src/olmo_core/utils.py
+++ b/src/olmo_core/utils.py
@@ -7,6 +7,7 @@
 import time
 import uuid
 import warnings
+from contextlib import contextmanager
 from datetime import datetime
 from itertools import cycle, islice
 from queue import Queue
@@ -608,3 +609,20 @@ def add_sub_dict(prefix: str, sub_dict: Dict[str, Any]):
             out[k] = v
 
     return out
+
+
+@contextmanager
+def cuda_sync_debug_mode(debug_mode: Union[int, str]):
+    """
+    A context manager for temporarily setting the CUDA sync debug mode.
+ """ + current_mode: Optional[int] = None + + try: + if torch.cuda.is_available(): + current_mode = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode(debug_mode) + yield + finally: + if current_mode is not None: + torch.cuda.set_sync_debug_mode(debug_mode)