Commit 7ec88be

ignore host-device syncs

epwalsh committed Oct 29, 2024
1 parent 78b0a2e commit 7ec88be
Showing 3 changed files with 35 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"omegaconf",
"safetensors",
"importlib_resources",
"ai2-olmo-eval==0.1.0",
"ai2-olmo-eval==0.2.0",
]

[project.urls]
Expand Down
23 changes: 16 additions & 7 deletions src/olmo_core/train/callbacks/evaluator_callback.py
@@ -11,7 +11,12 @@
 from olmo_core.eval import Evaluator
 from olmo_core.eval.lm_evaluator import LMEvaluator
 from olmo_core.exceptions import OLMoConfigurationError
-from olmo_core.utils import format_float, get_default_device, move_to_device
+from olmo_core.utils import (
+    cuda_sync_debug_mode,
+    format_float,
+    get_default_device,
+    move_to_device,
+)

 from ..common import Duration
 from .callback import Callback, CallbackConfig
@@ -71,7 +76,10 @@ def post_step(self):
                 logits, ce_loss, _ = self.trainer.eval_batch(
                     batch, loss_reduction="none", compute_z_loss=False
                 )
-            evaluator.update_metrics(batch, ce_loss, logits)
+
+            # NOTE: might have host-device syncs here but that's okay.
+            with cuda_sync_debug_mode(0):
+                evaluator.update_metrics(batch, ce_loss, logits)

             if eval_step % self.trainer.cancel_check_interval == 0:
                 self.trainer.check_if_canceled()
@@ -89,10 +97,11 @@ def post_step(self):
            # NOTE: going to have a host-device sync here but that's okay. It's only once
            # per evaluator.
            metrics = []
-           for name, value in evaluator.compute_metrics().items():
-               value = value.item()
-               metrics.append(f" {name}={format_float(value)}")
-               self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value)
+           with cuda_sync_debug_mode(0):
+               for name, value in evaluator.compute_metrics().items():
+                   value = value.item()
+                   metrics.append(f" {name}={format_float(value)}")
+                   self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value)
            log.info("Eval metrics:\n" + "\n".join(metrics))

# Restore model to train mode.
@@ -180,7 +189,7 @@ def __init__(
         rank_batch_size_instances = max(0, rank_batch_size // self.task.max_sequence_length)
         log.info(
             f"Using per-rank batch size of {rank_batch_size_instances} instances "
-            f"for downstream eval task '{task}'"
+            f"for downstream eval task '{task}' with max sequence length {self.task.max_sequence_length:,d} tokens"
         )

         data_loader = DataLoader(
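For context on what this commit is suppressing: calls like `value.item()` copy data from device to host and therefore block the CPU until the GPU catches up, which is exactly the kind of hidden sync PyTorch's sync debug mode is designed to flag. A minimal sketch of that mechanism, independent of this repo (the tensor name and size are illustrative):

```python
import torch

if torch.cuda.is_available():
    # "warn" makes PyTorch warn on every op that forces a host-device sync;
    # "error" raises instead, and "default" (0) allows syncs silently.
    torch.cuda.set_sync_debug_mode("warn")

    loss = torch.rand(1024, device="cuda").mean()
    value = loss.item()  # device-to-host copy -> emits a sync warning

    torch.cuda.set_sync_debug_mode("default")
    print(value)
```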
18 changes: 18 additions & 0 deletions src/olmo_core/utils.py
@@ -7,6 +7,7 @@
 import time
 import uuid
 import warnings
+from contextlib import contextmanager
 from datetime import datetime
 from itertools import cycle, islice
 from queue import Queue
@@ -608,3 +609,20 @@ def add_sub_dict(prefix: str, sub_dict: Dict[str, Any]):
             out[k] = v

     return out
+
+
+@contextmanager
+def cuda_sync_debug_mode(debug_mode: Union[int, str]):
+    """
+    A context manager for temporarily setting the CUDA sync debug mode.
+    """
+    current_mode: Optional[int] = None
+
+    try:
+        if torch.cuda.is_available():
+            current_mode = torch.cuda.get_sync_debug_mode()
+            torch.cuda.set_sync_debug_mode(debug_mode)
+        yield
+    finally:
+        if current_mode is not None:
+            torch.cuda.set_sync_debug_mode(current_mode)
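Note that the `finally` clause restores whatever mode was active before entering the context, even if the wrapped block raises. As a usage sketch (mine, not part of the diff, and assuming `olmo_core` is installed): with the global mode set to `"warn"`, wrapping a known-acceptable sync in `cuda_sync_debug_mode(0)` silences warnings for just that block:

```python
import torch

from olmo_core.utils import cuda_sync_debug_mode

if torch.cuda.is_available():
    # Complain about unexpected host-device syncs everywhere else.
    torch.cuda.set_sync_debug_mode("warn")

device = "cuda" if torch.cuda.is_available() else "cpu"
metric = torch.rand(8, device=device).sum()

# This .item() call is an expected sync, so temporarily drop back to the
# default mode (0) for just this block.
with cuda_sync_debug_mode(0):
    value = metric.item()

print(value)
```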
