Skip to content

Commit

Permalink
Merge pull request #687 from allenai/kylel/config-diff
Browse files Browse the repository at this point in the history
Extend functionality of Wandb Config Diff script
  • Loading branch information
kyleclo authored Oct 30, 2024
2 parents 8589c38 + f2c2a15 commit 837a4ff
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added ability to try loading latest checkpoint from save folder using `--try_load_latest_save`.
- Added support for flash attention and gradient checkpointing to `hf_olmo`.
- Added to `scripts.compare_wandb_configs.py` the ability to more easily compare differences in data mixes and evaluation tasks.
- Added `effective_n_kv_heads` to OLMoConfig for hacky VLLM support.

## [v0.5.0](https://github.com/allenai/OLMo/releases/tag/v0.5.0) - 2024-08-26
Expand Down
20 changes: 18 additions & 2 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,28 @@ def get_bytes_range(self, index: int, length: int) -> bytes:
return response["Body"].read()


def flatten_dict(dictionary, parent_key="", separator="."):
def flatten_dict(dictionary, parent_key="", separator=".", include_lists=False):
"""
Flatten a nested dictionary into a single-level dictionary.
Args:
dictionary (dict): The nested dictionary to be flattened.
parent_key (str, optional): The parent key to be prepended to the keys of the flattened dictionary. Defaults to "".
separator (str, optional): The separator to be used between the parent key and the keys of the flattened dictionary. Defaults to ".".
include_lists (bool, optional): Whether to convert lists to dictionaries with integer keys. Defaults to False.
Returns:
dict: The flattened dictionary.
"""
d: Dict[str, Any] = {}
for key, value in dictionary.items():
new_key = parent_key + separator + key if parent_key else key
# convert lists to dict with key <int>
if isinstance(value, list) and include_lists:
value = {f"{i}": v for i, v in enumerate(value)}
if isinstance(value, MutableMapping):
d.update(**flatten_dict(value, new_key, separator=separator))
d.update(**flatten_dict(value, new_key, separator=separator, include_lists=include_lists))
else:
d[new_key] = value
return d
135 changes: 111 additions & 24 deletions scripts/compare_wandb_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
"""
Examples:
Comparing Peteish7 to OLMoE
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmoe/runs/rzsn9tlc
Comparing Peteish7 to Amberish7
- python scripts/compare_wandb_configs.py https://wandb.ai/ai2-llm/olmo-medium/runs/cej4ya39 https://wandb.ai/ai2-llm/olmo-medium/runs/ij4ls6v2
"""

import logging
import os
import re
from collections import Counter

import click

Expand All @@ -24,6 +38,47 @@ def parse_run_path(run_path: str) -> str:
raise ValueError(f"Could not parse '{run_path}'")


def print_keys_with_differences(left_config, right_config):
s_left = ""
left_only_keys = left_config.keys() - right_config.keys()
if len(left_only_keys) > 0:
s_left += "Settings only in left:\n"
s_left += "\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys)) + "\n"

s_right = ""
right_only_keys = right_config.keys() - left_config.keys()
if len(right_only_keys) > 0:
s_right += "Settings only in right:\n"
s_right += "\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys)) + "\n"

s_shared = ""
keys_with_differences = {
k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k]
}
if len(keys_with_differences) > 0:
for k in sorted(keys_with_differences):
s_shared += f"{k}\n\t{left_config[k]}\n" + f"\t{right_config[k]}\n\n"

if (s_left or s_right) and not s_shared:
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + "No differences in shared settings.\n"
else:
s = s_left + "=" * 50 + "\n" + s_right + "=" * 50 + "\n" + s_shared
print(s.strip())


def print_data_differences(left_data_paths: Counter, right_data_paths: Counter):
print("===== Data Paths for left config:\n")
simplified_left_data_paths = {path: count for path, count in left_data_paths.items()}
for path, num_files in simplified_left_data_paths.items():
print(f"\t{path}: {num_files}")
print("\n\n")

print("===== Data Paths for right config:\n")
simplified_right_data_paths = {path: count for path, count in right_data_paths.items()}
for path, num_files in simplified_right_data_paths.items():
print(f"\t{path}: {num_files}")


@click.command()
@click.argument(
"left_run_path",
Expand All @@ -43,30 +98,62 @@ def main(
left_run = api.run(parse_run_path(left_run_path))
right_run = api.run(parse_run_path(right_run_path))

left_config = flatten_dict(left_run._attrs["rawconfig"])
right_config = flatten_dict(right_run._attrs["rawconfig"])

left_only_keys = left_config.keys() - right_config.keys()
if len(left_only_keys) > 0:
print("Settings only in left:")
print("\n".join(f"\t{k}: {left_config[k]}" for k in sorted(left_only_keys)))
print()

right_only_keys = right_config.keys() - left_config.keys()
if len(right_only_keys) > 0:
print("Settings only in right:")
print("\n".join(f"\t{k}: {right_config[k]}" for k in sorted(right_only_keys)))
print()

keys_with_differences = {
k for k in left_config.keys() & right_config.keys() if left_config[k] != right_config[k]
}
if len(keys_with_differences) > 0:
if len(left_only_keys) > 0 or len(right_only_keys) > 0:
print("Settings with differences:")
print("\n".join(f"{k}\n\t{left_config[k]}\n\t{right_config[k]}\n" for k in sorted(keys_with_differences)))
else:
print("No differences in shared settings.")
left_config_raw = left_run._attrs["rawconfig"]
right_config_raw = right_run._attrs["rawconfig"]

# flattening the dict will make diffs easier
left_config = flatten_dict(left_config_raw)
right_config = flatten_dict(right_config_raw)

# there are 2 specific fields in config that are difficult to diff:
# "evaluators" is List[Dict]
# "data.paths" is List[str]
# let's handle each of these directly.

# first, data.paths can be grouped and counted.
left_data_paths = Counter([os.path.dirname(path) for path in left_config["data.paths"]])
right_data_paths = Counter([os.path.dirname(path) for path in right_config["data.paths"]])
del left_config["data.paths"]
del right_config["data.paths"]

# next, evaluators can be added to the flat dict with unique key per evaluator
# also, each evaluator can also have a 'data.paths' field which needs collapsing
def _simplify_evaluator(evaluator):
evaluator = flatten_dict(evaluator)
if evaluator["data.paths"]:
evaluator["data.paths"] = Counter([os.path.dirname(path) for path in evaluator["data.paths"]])
return evaluator

def _simplify_evaluators(evaluators):
simplified_evaluators = {}
for evaluator in evaluators:
new_key = (".".join(["evaluators" + "." + evaluator["type"] + "." + evaluator["label"]])).upper()
simplified_evaluators[new_key] = _simplify_evaluator(evaluator)
return simplified_evaluators

left_evaluators = flatten_dict(_simplify_evaluators(left_config["evaluators"]), separator="___")
right_evaluators = flatten_dict(_simplify_evaluators(right_config["evaluators"]), separator="___")
del left_config["evaluators"]
del right_config["evaluators"]

print(
f"==================== Config differences between {left_run_path} and {right_run_path} ====================\n\n"
)

# print config differences
print("==================== Param differences ====================\n\n")
print_keys_with_differences(left_config=left_config, right_config=right_config)
print("============================================================= \n\n")

# print data differences
print("==================== Data Differences ====================\n\n")
print_data_differences(left_data_paths, right_data_paths)
print("============================================================= \n\n")

# print eval differences
print("==================== Eval Differences ====================\n\n")
print_keys_with_differences(left_config=left_evaluators, right_config=right_evaluators)
print("============================================================= \n\n")


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions tests/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,31 @@ def test_dir_is_empty(tmp_path):
# Should return false if dir contains anything, even hidden files.
(dir / ".foo").touch()
assert not util.dir_is_empty(dir)


def test_flatten_dict():
# basic flattening
test_dict = {"a": 0, "b": {"e": 5, "f": 1}, "c": 2}
assert util.flatten_dict(test_dict) == {"a": 0, "b.e": 5, "b.f": 1, "c": 2}

# Should flatten nested dicts into a single dict with dotted keys.
test_dict_with_list_of_dicts = {
"a": 0,
"b": {"e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], "f": 1},
"c": 2,
}
assert util.flatten_dict(test_dict_with_list_of_dicts) == {
"a": 0,
"b.e": [{"x": {"z": [222, 333]}}, {"y": {"g": [99, 100]}}], # doesnt get flattened
"b.f": 1,
"c": 2,
}
assert util.flatten_dict(test_dict_with_list_of_dicts, include_lists=True) == {
"a": 0,
"b.e.0.x.z.0": 222,
"b.e.0.x.z.1": 333,
"b.e.1.y.g.0": 99,
"b.e.1.y.g.1": 100,
"b.f": 1,
"c": 2,
}

0 comments on commit 837a4ff

Please sign in to comment.