Formatting
dirkgr committed Nov 26, 2024
1 parent d7623fb commit e892ce5
Showing 2 changed files with 176 additions and 172 deletions.
scripts/microanneal_config_maker.py: 186 changes (95 additions & 91 deletions)
@@ -1,46 +1,47 @@
-import boto3
+import glob
 import os
-from typing import List, Tuple
-from urllib.parse import urlparse
 import random
-import yaml
-import glob
-from tqdm.auto import tqdm
-from botocore.exceptions import ClientError
 from collections import defaultdict
-from tabulate import tabulate
 from concurrent.futures import ThreadPoolExecutor
+from typing import List, Tuple
+from urllib.parse import urlparse
+
+
+import boto3
+import yaml
+from botocore.exceptions import ClientError
+from tabulate import tabulate
+from tqdm.auto import tqdm
 
 # ===================================================
 # =                   S3 HELPERS                    =
 # ===================================================
 
 
 def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
     # Gets the size in bytes of an individual s3 path
     if s3_client is None:
-        s3_client = boto3.client('s3')
+        s3_client = boto3.client("s3")
 
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
     # Remove leading slash and handle edge cases
-    object_key = parsed.path.lstrip('/')
+    object_key = parsed.path.lstrip("/")
     try:
-        s3_client = boto3.client('s3')
+        s3_client = boto3.client("s3")
         response = s3_client.head_object(Bucket=bucket_name, Key=object_key)
-        return (None, response['ContentLength'])
+        return (None, response["ContentLength"])
     except Exception as e:
-        if e.response['Error']['Code'] == '404':
+        if e.response["Error"]["Code"] == "404":
             raise FileNotFoundError(f"The object {object_key} does not exist in bucket {bucket_name}")
         else:
             raise (e, 0)
 
 
 def get_batch_s3_size(s3_uris: List[str]):
     # Faster way to get size in bytes for a lot of s3 paths: maps s3_uri -> size
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
     errors = []
 
     def partial_size(s3_uri: str):
         output, size = get_single_s3_size(s3_uri, s3_client=s3_client)
         if output != None:
@@ -53,57 +54,55 @@ def partial_size(s3_uri: str):
     for future in tqdm(futures, total=len(futures)):
         results.append(future.result())
 
-
     # Convert results to dictionary
     sizes = dict(results)
     return sizes
-def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
+
+
+def list_s3_paths(s3_uri: str, extension: str = ".npy") -> List[Tuple[str, int]]:
     """
     Lists all paths in an S3 bucket with given prefix and extension, along with their sizes.
     Args:
         bucket_name (str): Name of the S3 bucket
         prefix (str): Prefix to filter objects (e.g., 'data/')
        extension (str): File extension to filter (e.g., '.csv')
     Returns:
         List[Tuple[str, int]]: List of tuples containing (path, size in bytes)
     """
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
 
     # Remove leading slash and handle edge cases
-    prefix = parsed.path.lstrip('/')
+    prefix = parsed.path.lstrip("/")
 
-
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
 
     # Ensure prefix ends with '/' if it's meant to be a directory
-    if prefix and not prefix.endswith('/'):
-        prefix += '/'
+    if prefix and not prefix.endswith("/"):
+        prefix += "/"
 
     # Ensure extension starts with '.'
-    if not extension.startswith('.'):
-        extension = '.' + extension
+    if not extension.startswith("."):
+        extension = "." + extension
 
     paths_and_sizes = []
-    paginator = s3_client.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator("list_objects_v2")
 
     try:
         # Paginate through results to handle large buckets
         for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
-            if 'Contents' not in page:
+            if "Contents" not in page:
                 continue
-            for obj in page['Contents']:
-                key = obj['Key']
-
+            for obj in page["Contents"]:
+                key = obj["Key"]
                 if key.endswith(extension):
-                    paths_and_sizes.append((key, obj['Size']))
+                    paths_and_sizes.append((key, obj["Size"]))
 
         return paths_and_sizes
 
     except Exception as e:
         print(f"Error listing objects: {str(e)}")
         return []
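
The docstring above spells out the contract; as a quick illustration of how these two helpers compose (the bucket URI here is a made-up placeholder, not one from the repo):

    # Sketch only: "s3://my-bucket/tokenized/" is hypothetical.
    keys_and_sizes = list_s3_paths("s3://my-bucket/tokenized/", extension=".npy")
    uris = ["s3://my-bucket/%s" % key for key, _ in keys_and_sizes]  # keys come back bucket-relative
    sizes = get_batch_s3_size(uris)  # maps s3_uri -> size in bytes
    total_tokens = sum(sizes.values()) // 4  # uint32 tokens, 4 bytes each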
@@ -113,8 +112,7 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
 # =                Other config-specific helpers                 =
 # =================================================================
 
-BASE_YAML_STR = \
-'''run_name: REPLACE_RUN_NAME_HERE
+BASE_YAML_STR = """run_name: REPLACE_RUN_NAME_HERE
 seed: 7201
 dry_run: false
@@ -361,17 +359,17 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
   repetition_max_period: 13
   repetition_min_period: 1
   repetition_max_count: 32
-  paths:'''
+  paths:"""
 
 
 def human_format_number(num, decimal_places=2):
     """
     Format a number using K for thousands, M for millions, B for billions, T for trillions.
     Args:
         num: Number to format
         decimal_places: Number of decimal places to show (default: 2)
     Examples:
         format_number(999) => '999'
         format_number(1000) => '1.00K'
@@ -380,21 +378,21 @@ def human_format_number(num, decimal_places=2):
         format_number(1500000000) => '1.50B'
     """
     abs_num = abs(num)
-    sign = '-' if num < 0 else ''
+    sign = "-" if num < 0 else ""
 
     if abs_num < 1000:
         return f"{sign}{abs_num}"
-    suffixes = ['', 'K', 'M', 'B', 'T']
+
+    suffixes = ["", "K", "M", "B", "T"]
     magnitude = 0
-
     while abs_num >= 1000 and magnitude < len(suffixes) - 1:
         abs_num /= 1000
         magnitude += 1
 
     # Format with specified decimal places
     formatted = f"{abs_num:.{decimal_places}f}"
-
     return f"{sign}{formatted}{suffixes[magnitude]}"
 
 
@@ -407,53 +405,59 @@ def get_token_strs(token_source, bytes_per_token=4):

     paths_and_sizes = list_s3_paths(s3_source)
     parsed = urlparse(s3_source)
-    bucket_name = parsed.netloc
-    paths_and_sizes = [('s3://%s/%s' % (bucket_name, p), s) for p,s in paths_and_sizes]
+    bucket_name = parsed.netloc
+    paths_and_sizes = [("s3://%s/%s" % (bucket_name, p), s) for p, s in paths_and_sizes]
     random.shuffle(paths_and_sizes)
     total_tokens = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     target_tokens = total_tokens * ratio
 
     paths_to_add = []
     tokens_to_add = 0
     for p, s in paths_and_sizes:
         paths_to_add.append(p)
         tokens_to_add += s // bytes_per_token
-        if tokens_to_add >= target_tokens:
+        if tokens_to_add >= target_tokens:
             break
-    lines_to_add = ['#SOURCE: %s (%sT)' % (s3_source, human_format_number(tokens_to_add))]
+    lines_to_add = ["#SOURCE: %s (%sT)" % (s3_source, human_format_number(tokens_to_add))]
     for p in paths_to_add:
-        lines_to_add.append('- %s' % p)
+        lines_to_add.append("- %s" % p)
     return lines_to_add
 
 
-def add_paths(token_sources, output_yaml_file, start_point='preanneal'):
+def add_paths(token_sources, output_yaml_file, start_point="preanneal"):
     # Adds things to the yaml file.
     # Token sources is a list of either... s3_uri: str | (s3_uri: str, fraction: float)
     # Also I'm not bothering with pyyaml, just appending to the base config (which will be included)
     # ^this is a very crude stone-age tool, don't @ me
 
-    assert os.path.basename(output_yaml_file).startswith('peteish7-weka-microanneal')
-    assert output_yaml_file.endswith('.yaml')
+    assert os.path.basename(output_yaml_file).startswith("peteish7-weka-microanneal")
+    assert output_yaml_file.endswith(".yaml")
 
-    assert start_point in ['preanneal', 'megamath5000']
-    base_config_str = BASE_YAML_STR.replace("REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0])
+    assert start_point in ["preanneal", "megamath5000"]
+    base_config_str = BASE_YAML_STR.replace(
+        "REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0]
+    )
 
     # Change input model, LR
-    if start_point == 'preanneal':
-        base_config_str = base_config_str.replace('REPLACE_PATH_HERE', '/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7/step928646')
-        base_config_str = base_config_str.replace('REPLACE_LR_HERE', '0.000061499')
-    elif start_point == 'megamath5000':
-        base_config_str = base_config_str.replace('REPLACE_PATH_HERE', '/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-weka-anneal-from-928646-50B-megamath_v1.1.yaml/step5000/')
-        new_lr = '%.09f' % (0.000061499 * (1 - (5000 / 11931)))
-        base_config_str = base_config_str.replace('REPLACE_LR_HERE', new_lr)
-
-
+    if start_point == "preanneal":
+        base_config_str = base_config_str.replace(
+            "REPLACE_PATH_HERE", "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7/step928646"
+        )
+        base_config_str = base_config_str.replace("REPLACE_LR_HERE", "0.000061499")
+    elif start_point == "megamath5000":
+        base_config_str = base_config_str.replace(
+            "REPLACE_PATH_HERE",
+            "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-weka-anneal-from-928646-50B-megamath_v1.1.yaml/step5000/",
+        )
+        new_lr = "%.09f" % (0.000061499 * (1 - (5000 / 11931)))
+        base_config_str = base_config_str.replace("REPLACE_LR_HERE", new_lr)
 
     lines_to_add = []
     for source in token_sources:
         lines_to_add.extend(get_token_strs(source))
-    true_lines_to_add = ['\n %s' % line for line in lines_to_add]
-    output_str = base_config_str + ''.join(true_lines_to_add)
-    with open(output_yaml_file, 'w') as f:
+    true_lines_to_add = ["\n %s" % line for line in lines_to_add]
+    output_str = base_config_str + "".join(true_lines_to_add)
+    with open(output_yaml_file, "w") as f:
         f.write(output_str)
 
 
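
The token_sources comment above defines the input shape; a minimal call might look like the following sketch (the S3 prefixes are invented placeholders, and the output filename has to satisfy the asserts). Note the LR arithmetic in the megamath5000 branch continues a linear anneal: 0.000061499 * (1 - 5000/11931) comes to roughly 0.0000357.

    # Hypothetical usage; a bare URI presumably takes the whole source,
    # while a (uri, fraction) tuple samples that fraction of its tokens.
    add_paths(
        token_sources=[
            "s3://my-bucket/source-a/",
            ("s3://my-bucket/source-b/", 0.25),
        ],
        output_yaml_file="peteish7-weka-microanneal-example.yaml",
        start_point="preanneal",
    )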
@@ -466,11 +470,10 @@ def examine_config(yaml_file, bytes_per_token=4):

print("Getting tokens per input file...")
# Step 1: collect all paths of tokens
with open(yaml_file, 'r') as f:
with open(yaml_file, "r") as f:
yaml_content = yaml.safe_load(f)
paths = yaml_content.get('data', {}).get('paths', [])
paths_to_tokens = {k: v // bytes_per_token for k,v in get_batch_s3_size(paths).items()}

paths = yaml_content.get("data", {}).get("paths", [])
paths_to_tokens = {k: v // bytes_per_token for k, v in get_batch_s3_size(paths).items()}

# Step 2: Gather all sources, count tokens taken
print("Grouping output files into groups...")
@@ -494,27 +497,28 @@ def get_group(s3_uri):
         total_tokens[g] = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     print("TOTAL_TOKENS", total_tokens)
     # Step 4: get ratios of percentage taken
-    ratios = {g: '%.04f' % (tokens_taken[g] / total_tokens[g])
-              for g in groups} # .04f here (ranging from 0.00 to 1.00)
+    ratios = {
+        g: "%.04f" % (tokens_taken[g] / total_tokens[g]) for g in groups
+    }  # .04f here (ranging from 0.00 to 1.00)
 
     # Step 5: actually print the outputs
     rows = sorted([(g, total_tokens[g], ratios[g], tokens_taken[g]) for g in groups])
     print("Put this in your spreadsheet!")
-    print(tabulate(rows, headers=['paths', 'total_tokens', 'percentage taken', 'tokens taken']))
+    print(tabulate(rows, headers=["paths", "total_tokens", "percentage taken", "tokens taken"]))
 
 
 def _read_path_comments(yaml_file):
     # This is helpful for examining paths
-    lines = open(yaml_file,'r').readlines()
+    lines = open(yaml_file, "r").readlines()
     path_sources = []
     seen_paths = False
     for line in lines:
-        if not seen_paths and line.strip() != 'paths:':
+        if not seen_paths and line.strip() != "paths:":
             continue
-        elif line.strip() == 'paths:':
+        elif line.strip() == "paths:":
             seen_paths = True
-        elif line.strip().startswith('#'):
-            path_sources.append(line.strip().split(' ')[1])
+        elif line.strip().startswith("#"):
+            path_sources.append(line.strip().split(" ")[1])
         else:
             pass
     return path_sources
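
For reference, the block that _read_path_comments scans is the paths section that add_paths appends to the base config; it looks roughly like this (URIs invented, exact indentation determined by true_lines_to_add above). The function keeps only the #SOURCE marker lines, splitting on spaces and taking the URI.

    paths:
      #SOURCE: s3://my-bucket/source-a/ (1.50BT)
      - s3://my-bucket/source-a/part-00000.npy
      - s3://my-bucket/source-a/part-00001.npy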
@@ -525,7 +529,7 @@ def _read_path_comments(yaml_file):
 # =================================================
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     """
     Use this interactively like `python -i peteish7_config_maker.py`, since tuples are weird to pass
     [ or load all these modules into a jupyter notebook ]
@@ -539,4 +543,4 @@ def _read_path_comments(yaml_file):
     # and then you can populate the spreadsheet with the output of examine_config
     print(examine_config(OUTPUT_YAML))
-    """
+    """