Add ability to save and compare sub-module outputs #690

Merged (9 commits) on Aug 12, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
- Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.
- Added `CosLinearEnvelope` scheduler, which is a pointwise product of a cosine schedule and a linear decay.
- Added ability to save outputs of submodules for debugging purposes.

### Changed

6 changes: 6 additions & 0 deletions olmo/config.py
@@ -1218,6 +1218,12 @@ class TrainConfig(BaseConfig):
Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
"""

module_outputs_save_steps: Optional[List[int]] = None
"""
    Outputs of model submodules are saved at the given training steps. Saved outputs
    can be compared across runs using `scripts/compare_module_outputs.py`.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
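A quick sketch of how the new option might be enabled. This assumes the repo's omegaconf-style `TrainConfig.load(path, overrides=[...])` loader; the config path and step values are illustrative only:

# Hypothetical sketch: save submodule outputs at steps 0 and 100.
# `TrainConfig.load` with dotlist overrides is assumed from the repo's
# omegaconf-based BaseConfig; the YAML path below is made up.
from olmo.config import TrainConfig

cfg = TrainConfig.load(
    "configs/my_run.yaml",
    overrides=["module_outputs_save_steps=[0,100]", "save_overwrite=true"],
)
assert cfg.module_outputs_save_steps == [0, 100]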
58 changes: 58 additions & 0 deletions olmo/train.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import cProfile
import functools
import gc
import logging
import math
@@ -20,6 +21,8 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils
import torch.utils.hooks
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -650,6 +653,53 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
else:
raise NotImplementedError(checkpoint_type)

def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
if (
self.cfg.module_outputs_save_steps is None
or self.global_step not in self.cfg.module_outputs_save_steps
):
return []

if micro_batch_idx != 0 or get_global_rank() != 0:
            # Hooks are currently registered only for the first micro-batch on rank 0
return []

trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
if trace_save_folder.exists():
if self.cfg.save_overwrite:
shutil.rmtree(trace_save_folder)
else:
raise OLMoConfigurationError(
f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
)
trace_save_folder.mkdir(parents=True)

def trace_outputs_hook(
module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
) -> None:
if len(args) == 0:
log.info("No input args for module %s, output %s", module_name, output)

module_input = args[0] if len(args) > 0 else torch.tensor(())
trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
trace_save_folder.mkdir(parents=True, exist_ok=True)

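            # Note: a module may run more than once per forward pass, so find the
            # first unused occurrence index rather than overwriting earlier traces.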
            module_occurrence_num = 0
            while (
                module_input_filepath := trace_save_folder / f"{module_name}_{module_occurrence_num}_input.pt"
            ).exists():
                module_occurrence_num += 1
            torch.save(module_input, module_input_filepath)

            module_output_filepath = trace_save_folder / f"{module_name}_{module_occurrence_num}_output.pt"
torch.save(output, module_output_filepath)

output_hooks = []
for module_name, module in self.model.named_modules(prefix="model"):
output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))

return output_hooks

def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
# Labels are just input IDs shifted to the left (first item is ignored).
labels, label_mask, attention_mask, instance_mask = (
@@ -740,6 +790,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
if micro_batch_idx != num_micro_batches - 1:
grad_sync_context = self.dist_model.no_sync

# Register output hooks
output_hooks: List[torch.utils.hooks.RemovableHandle] = []
output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)

with grad_sync_context():
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
# Run forward pass.
@@ -756,6 +810,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
# Run backward pass.
loss.backward()

# Remove output hooks
for hook in output_hooks:
hook.remove()

return ce_batch_loss, z_batch_loss

def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
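For readers unfamiliar with PyTorch forward hooks, here is a minimal, self-contained sketch of the mechanism the trainer uses above. `register_forward_hook` callbacks receive `(module, inputs, output)`, which is why `functools.partial` binds the module name as a leading argument; the toy model and file names are illustrative only:

import functools
from typing import Tuple

import torch

def save_hook(name: str, _m: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
    # Mirrors the trainer's hook: persist the first input and the output.
    torch.save(args[0], f"{name}_0_input.pt")
    torch.save(output, f"{name}_0_output.pt")

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.GELU())
handles = [
    m.register_forward_hook(functools.partial(save_hook, n))
    for n, m in model.named_modules(prefix="model")
]
model(torch.randn(2, 4))  # traces are written as the forward pass runs
for handle in handles:
    handle.remove()  # detach hooks so later steps run untraced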
139 changes: 139 additions & 0 deletions scripts/compare_module_outputs.py
@@ -0,0 +1,139 @@
import logging
from argparse import ArgumentParser
from pathlib import Path
from typing import List

import torch

logger = logging.getLogger(__name__)


def _get_module_names(checkpoint_traces_folder: Path) -> List[str]:
module_names = []
for trace_file in checkpoint_traces_folder.iterdir():
trace_file_name = trace_file.name
if trace_file_name.endswith("_input.pt"):
module_name = trace_file_name.removesuffix("_input.pt")
elif trace_file_name.endswith("_output.pt"):
module_name = trace_file_name.removesuffix("_output.pt")
        else:
            logger.warning("Cannot get module name from file %s, skipping", trace_file_name)
            continue

        module_names.append(module_name)

return module_names


def compare_module_output(
base_traces_folder: Path,
compare_traces_folder: Path,
module_name: str,
*,
include_non_tensor_outputs: bool = True,
verbose: bool = False,
):
base_module_input_path = base_traces_folder / f"{module_name}_input.pt"
base_module_output_path = base_traces_folder / f"{module_name}_output.pt"
compare_module_input_path = compare_traces_folder / f"{module_name}_input.pt"
compare_module_output_path = compare_traces_folder / f"{module_name}_output.pt"

map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
base_input = torch.load(str(base_module_input_path), map_location=map_location)
compare_input = torch.load(str(compare_module_input_path), map_location=map_location)

if verbose or base_input.dtype != compare_input.dtype:
logger.info("%s input dtypes: %s %s", module_name, base_input.dtype, compare_input.dtype)
if verbose or base_input.shape != compare_input.shape:
logger.info("%s input shapes: %s %s", module_name, base_input.shape, compare_input.shape)
if (norm_diff := torch.linalg.vector_norm((compare_input - base_input).float()).item()) != 0.0 or verbose:
logger.info("%s input norm diff: %.6f", module_name, norm_diff)
if "wte" in module_name:
logger.info(
"%s mis-matching wte elements: %d",
module_name,
torch.sum(torch.logical_not(torch.eq(base_input, compare_input))),
)

base_output = torch.load(str(base_module_output_path), map_location=map_location)
compare_output = torch.load(str(compare_module_output_path), map_location=map_location)

if isinstance(base_output, torch.Tensor):
if verbose or base_output.dtype != compare_output.dtype:
logger.info("%s output dtypes: %s %s", module_name, base_output.dtype, compare_output.dtype)
if (
norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()
) != 0.0 or verbose:
logger.info("%s output norm diff: %.6f", module_name, norm_diff)
elif include_non_tensor_outputs:
logger.info("%s outputs: %s %s", module_name, base_output, compare_output)
else:
if verbose:
logger.info("Base output is type %s, skipping", type(base_output))


def compare_module_outputs(
base_traces_folder: Path,
compare_traces_folder: Path,
*,
include_non_tensor_outputs: bool = True,
verbose: bool = False,
):
base_modules = set(_get_module_names(base_traces_folder))
compare_modules = set(_get_module_names(compare_traces_folder))

base_only_modules = base_modules - compare_modules
if len(base_only_modules) > 0:
logger.info("Base-only modules: %s", ", ".join(base_only_modules))

compare_only_modules = compare_modules - base_modules
if len(compare_only_modules) > 0:
logger.info("Compare-only modules: %s", ", ".join(compare_only_modules))

common_modules = base_modules.intersection(compare_modules)
for module_name in sorted(common_modules):
compare_module_output(
base_traces_folder,
compare_traces_folder,
module_name,
include_non_tensor_outputs=include_non_tensor_outputs,
verbose=verbose,
)


def main():
logging.basicConfig(encoding="utf-8", level=logging.INFO)

parser = ArgumentParser()
parser.add_argument(
"base_model_traces_path",
type=Path,
help="Path where output traces of the base (i.e. reference) model are stored",
)
parser.add_argument(
"compare_model_traces_path",
type=Path,
help="Path where output traces of the compare (a.k.a new, different) model are stored",
)
parser.add_argument(
"--include_non_tensor_outputs",
action="store_true",
dest="include_non_tensor_outputs",
help="If set, compare module outputs that are not tensors",
)
parser.add_argument(
"--verbose",
action="store_true",
help="If set, show extra information",
)

args = parser.parse_args()
compare_module_outputs(
args.base_model_traces_path,
args.compare_model_traces_path,
include_non_tensor_outputs=args.include_non_tensor_outputs,
verbose=args.verbose,
)


if __name__ == "__main__":
main()
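Typical usage, assuming two runs traced the same step into `<save_folder>/traces/step<N>` as the trainer writes them above (the run directories here are made up):

# Hypothetical comparison of step-0 traces from two runs.
from pathlib import Path

from compare_module_outputs import compare_module_outputs

compare_module_outputs(
    Path("runs/base/traces/step0"),
    Path("runs/candidate/traces/step0"),
    include_non_tensor_outputs=False,
    verbose=True,
)

The equivalent command line would be `python scripts/compare_module_outputs.py runs/base/traces/step0 runs/candidate/traces/step0 --verbose`.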