Skip to content

Commit

Permalink
add richer diff functionality between configs
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleclo committed Oct 10, 2024
1 parent f824795 commit 147013f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 7 deletions.
16 changes: 15 additions & 1 deletion olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import sys
import time
import warnings
from collections import defaultdict
from datetime import datetime
from enum import Enum
from itertools import cycle, islice
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union

import boto3
import botocore.exceptions as boto_exceptions
Expand Down Expand Up @@ -842,6 +843,19 @@ def get_bytes_range(self, index: int, length: int) -> bytes:


def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False):
"""
Flatten a nested dictionary into a single-level dictionary.
Args:
dictionary (dict): The nested dictionary to be flattened.
parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "".
separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".".
include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False.
Returns:
dict: The flattened dictionary.
"""
d: Dict[str, Any] = {}
for key, value in dictionary.items():
new_key = parent_key + separator + key if parent_key else key
Expand Down
113 changes: 107 additions & 6 deletions scripts/compare_wandb_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
"""
Examples:
Comparing Peteish7 to OLMoE
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmoe/runs/rzsn9tlc
Comparing Peteish7 to Amberish7
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmo-medium/runs/ij4ls6v2
"""

import logging
import os
import re
from collections import Counter

import click

from olmo.util import flatten_dict, prepare_cli_environment
from olmo.util import (
build_file_tree,
file_tree_to_strings,
flatten_dict,
prepare_cli_environment,
)

log = logging.getLogger(__name__)
run_path_re = re.compile(r"^[^/]+/[^/]+/[^/]+$")
Expand Down Expand Up @@ -58,13 +77,44 @@ def print_keys_with_differences(left_config, right_config, level=0):
s_shared += prefix + f"{k}\n\t{left_config[k]}\n" + prefix + f"\t{right_config[k]}\n\n"

if (s_left or s_right) and not s_shared:
s = s_left + s_right + prefix + "No differences in shared settings.\n"
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + prefix + "No differences in shared settings.\n"
else:
s = s_left + s_right + s_shared
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + s_shared
print(s.strip())
return


def simplify_path(path):
# Remove the variable prefix up to and including 'preprocessed/'
path = re.sub(r"^.*?preprocessed/", "", path)

# Split the path into components
components = path.split("/")

# Remove common suffixes like 'allenai'
components = [c for c in components if c != "allenai"]

if components:
return "__".join(components)
else:
return "unknown_dataset"


def print_data_differences(left_data_paths: Counter, right_data_paths: Counter):
print(f"===== Data Paths for left config:\n")
left_data_paths = {simplify_path(path): count for path, count in left_data_paths.items()}
for path, num_files in left_data_paths.items():
print(f"\t{path}: {num_files}")
print("\n\n")

print(f"===== Data Paths for right config:\n")
right_data_paths = {simplify_path(path): count for path, count in right_data_paths.items()}
for path, num_files in right_data_paths.items():
print(f"\t{path}: {num_files}")

return


@click.command()
@click.argument(
"left_run_path",
Expand All @@ -84,10 +134,61 @@ def main(
left_run = api.run(parse_run_path(left_run_path))
right_run = api.run(parse_run_path(right_run_path))

left_config = flatten_dict(left_run._attrs["rawconfig"], include_lists=True)
right_config = flatten_dict(right_run._attrs["rawconfig"], include_lists=True)

left_config_raw = left_run._attrs["rawconfig"]
right_config_raw = right_run._attrs["rawconfig"]

# flattening the dict will make diffs easier
left_config = flatten_dict(left_config_raw)
right_config = flatten_dict(right_config_raw)

# there are 2 specific fields in config that are difficult to diff:
# "evaluators" is List[Dict]
# "data.paths" is List[str]
# let's handle each of these directly.

# first, data.paths can be grouped and counted.
left_data_paths = Counter([os.path.dirname(path) for path in left_config["data.paths"]])
# for path, num_files in left_data_paths.items():
# new_key = "data.paths" + "." + path
# left_config[new_key] = f"Num Files: {num_files}"
right_data_paths = Counter([os.path.dirname(path) for path in right_config["data.paths"]])
# for path, num_files in right_data_paths.items():
# new_key = "data.paths" + "." + path
# right_config[new_key] = f"Num Files: {num_files}"
del left_config["data.paths"]
del right_config["data.paths"]

# next, evaluators can be added to the flat dict with unique key per evaluator
# also, each evaluator can also have a 'data.paths' field which needs collapsing
left_evaluators = {}
for evaluator in left_config["evaluators"]:
new_key = ".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]])
if evaluator["data"]["paths"]:
evaluator["data"]["paths"] = Counter([os.path.dirname(path) for path in evaluator["data"]["paths"]])
left_evaluators[new_key] = evaluator
right_evaluators = {}
for evaluator in right_config["evaluators"]:
new_key = ".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]])
if evaluator["data"]["paths"]:
evaluator["data"]["paths"] = Counter([os.path.dirname(path) for path in evaluator["data"]["paths"]])
right_evaluators[new_key] = evaluator
del left_config["evaluators"]
del right_config["evaluators"]

# print config differences
print(f"===== Config differences between {left_run_path} and {right_run_path}:\n")
print_keys_with_differences(left_config=left_config, right_config=right_config)
print("\n\n")

# print data differences
print(f"===== Data differences between {left_run_path} and {right_run_path}:\n")
print_data_differences(left_data_paths, right_data_paths)
print("\n\n")

# print eval differences
print(f"===== Evaluator differences between {left_run_path} and {right_run_path}:\n")
print_keys_with_differences(left_config=left_evaluators, right_config=right_evaluators)
print("\n\n")


if __name__ == "__main__":
Expand Down

0 comments on commit 147013f

Please sign in to comment.