diff --git a/scripts/microanneal_config_maker.py b/scripts/microanneal_config_maker.py
index 8eb89e8d4..82b43f638 100644
--- a/scripts/microanneal_config_maker.py
+++ b/scripts/microanneal_config_maker.py
@@ -1,37 +1,37 @@
-import boto3
+import glob
 import os
-from typing import List, Tuple
-from urllib.parse import urlparse
 import random
-import yaml
-import glob
-from tqdm.auto import tqdm
-from botocore.exceptions import ClientError
 from collections import defaultdict
-from tabulate import tabulate
 from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Tuple
+from urllib.parse import urlparse
-
+import boto3
+import yaml
+from botocore.exceptions import ClientError
+from tabulate import tabulate
+from tqdm.auto import tqdm
 
 # ===================================================
 # =                  S3 HELPERS                     =
 # ===================================================
 
-def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
+
+def get_single_s3_size(s3_uri: str, s3_client=None) -> Tuple[Optional[Exception], int]:
     # Gets the size in bytes of an individual s3 path
     if s3_client is None:
-        s3_client = boto3.client('s3')
-
+        s3_client = boto3.client("s3")
+
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
     # Remove leading slash and handle edge cases
-    object_key = parsed.path.lstrip('/')
+    object_key = parsed.path.lstrip("/")
     try:
-        s3_client = boto3.client('s3')
+        # reuse the client passed in (or created above) instead of re-creating one per call
         response = s3_client.head_object(Bucket=bucket_name, Key=object_key)
-        return (None, response['ContentLength'])
-    except Exception as e:
-        if e.response['Error']['Code'] == '404':
+        return (None, response["ContentLength"])
+    except ClientError as e:
+        if e.response["Error"]["Code"] == "404":
             raise FileNotFoundError(f"The object {object_key} does not exist in bucket {bucket_name}")
         else:
-            raise (e, 0)
+            return (e, 0)
@@ -39,8 +39,9 @@ def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
 
 def get_batch_s3_size(s3_uris: List[str]):
     # Faster way to get size in bytes for a lot of s3 paths: maps s3_uri -> size
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
     errors = []
+
     def partial_size(s3_uri: str):
         output, size = get_single_s3_size(s3_uri, s3_client=s3_client)
-        if output != None:
+        if output is not None:
@@ -53,57 +54,55 @@ def partial_size(s3_uri: str):
         for future in tqdm(futures, total=len(futures)):
             results.append(future.result())
 
-    # Convert results to dictionary
     sizes = dict(results)
     return sizes
-
-
-def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
+
+
+def list_s3_paths(s3_uri: str, extension: str = ".npy") -> List[Tuple[str, int]]:
     """
     Lists all paths in an S3 bucket with given prefix and extension, along with their sizes.
-    
+
     Args:
-        bucket_name (str): Name of the S3 bucket
-        prefix (str): Prefix to filter objects (e.g., 'data/')
+        s3_uri (str): S3 URI giving the bucket and key prefix to list
+            (e.g., 's3://bucket/data/')
         extension (str): File extension to filter (e.g., '.csv')
-    
+
     Returns:
         List[Tuple[str, int]]: List of tuples containing (path, size in bytes)
     """
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
-    
+
     # Remove leading slash and handle edge cases
-    prefix = parsed.path.lstrip('/')
+    prefix = parsed.path.lstrip("/")
+
+    s3_client = boto3.client("s3")
-
-    s3_client = boto3.client('s3')
-
     # Ensure prefix ends with '/' if it's meant to be a directory
-    if prefix and not prefix.endswith('/'):
-        prefix += '/'
-
+    if prefix and not prefix.endswith("/"):
+        prefix += "/"
+
     # Ensure extension starts with '.'
-    if not extension.startswith('.'):
-        extension = '.' + extension
-
+    if not extension.startswith("."):
+        extension = "." + extension
+
     paths_and_sizes = []
-    paginator = s3_client.get_paginator('list_objects_v2')
-
+    paginator = s3_client.get_paginator("list_objects_v2")
+
     try:
         # Paginate through results to handle large buckets
        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
-            if 'Contents' not in page:
+            if "Contents" not in page:
                 continue
-
-            for obj in page['Contents']:
-                key = obj['Key']
+
+            for obj in page["Contents"]:
+                key = obj["Key"]
                 if key.endswith(extension):
-                    paths_and_sizes.append((key, obj['Size']))
-
+                    paths_and_sizes.append((key, obj["Size"]))
+
         return paths_and_sizes
-
+
     except Exception as e:
         print(f"Error listing objects: {str(e)}")
         return []
@@ -113,8 +112,7 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
 # =               Other config-specific helpers                  =
 # =================================================================
 
-BASE_YAML_STR = \
-'''run_name: REPLACE_RUN_NAME_HERE
+BASE_YAML_STR = """run_name: REPLACE_RUN_NAME_HERE
 seed: 7201
 dry_run: false
 
@@ -361,17 +359,17 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
       repetition_max_period: 13
       repetition_min_period: 1
       repetition_max_count: 32
-  paths:'''
+  paths:"""
 
 
 def human_format_number(num, decimal_places=2):
     """
     Format a number using K for thousands, M for millions, B for billions, T for trillions.
-    
+
     Args:
         num: Number to format
         decimal_places: Number of decimal places to show (default: 2)
-    
+
     Examples:
-        format_number(999) => '999'
-        format_number(1000) => '1.00K'
+        human_format_number(999) => '999'
+        human_format_number(1000) => '1.00K'
@@ -380,21 +378,21 @@ def human_format_number(num, decimal_places=2):
-        format_number(1500000000) => '1.50B'
+        human_format_number(1500000000) => '1.50B'
     """
     abs_num = abs(num)
-    sign = '-' if num < 0 else ''
-
+    sign = "-" if num < 0 else ""
+
     if abs_num < 1000:
         return f"{sign}{abs_num}"
-
-    suffixes = ['', 'K', 'M', 'B', 'T']
+
+    suffixes = ["", "K", "M", "B", "T"]
     magnitude = 0
-
+
     while abs_num >= 1000 and magnitude < len(suffixes) - 1:
         abs_num /= 1000
         magnitude += 1
-
+
     # Format with specified decimal places
     formatted = f"{abs_num:.{decimal_places}f}"
-
+
     return f"{sign}{formatted}{suffixes[magnitude]}"
@@ -407,53 +405,59 @@ def get_token_strs(token_source, bytes_per_token=4):
         paths_and_sizes = list_s3_paths(s3_source)
     parsed = urlparse(s3_source)
-    bucket_name = parsed.netloc 
-    paths_and_sizes = [('s3://%s/%s' % (bucket_name, p), s) for p,s in paths_and_sizes]
+    bucket_name = parsed.netloc
+    paths_and_sizes = [("s3://%s/%s" % (bucket_name, p), s) for p, s in paths_and_sizes]
     random.shuffle(paths_and_sizes)
     total_tokens = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     target_tokens = total_tokens * ratio
-
+
     paths_to_add = []
     tokens_to_add = 0
     for p, s in paths_and_sizes:
         paths_to_add.append(p)
         tokens_to_add += s // bytes_per_token
-        if tokens_to_add >= target_tokens: 
+        if tokens_to_add >= target_tokens:
             break
-    lines_to_add = ['#SOURCE: %s (%sT)' % (s3_source, human_format_number(tokens_to_add))]
+    lines_to_add = ["#SOURCE: %s (%sT)" % (s3_source, human_format_number(tokens_to_add))]
     for p in paths_to_add:
-        lines_to_add.append('- %s' % p)
+        lines_to_add.append("- %s" % p)
     return lines_to_add
 
 
-def add_paths(token_sources, output_yaml_file, start_point='preanneal'):
+def add_paths(token_sources, output_yaml_file, start_point="preanneal"):
     # Adds things to the yaml file.
     # Token sources is a list of either... s3_uri: str | (s3_uri: str, fraction: float)
     # Also I'm not bothering with pyyaml, just appending to the base config (which will be included)
     # ^this is a very crude stone-age tool, don't @ me
-    assert os.path.basename(output_yaml_file).startswith('peteish7-weka-microanneal')
-    assert output_yaml_file.endswith('.yaml')
+    assert os.path.basename(output_yaml_file).startswith("peteish7-weka-microanneal")
+    assert output_yaml_file.endswith(".yaml")
 
-    assert start_point in ['preanneal', 'megamath5000']
-    base_config_str = BASE_YAML_STR.replace("REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0])
+    assert start_point in ["preanneal", "megamath5000"]
+    base_config_str = BASE_YAML_STR.replace(
+        "REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0]
+    )
 
     # Change input model, LR
-    if start_point == 'preanneal':
-        base_config_str = base_config_str.replace('REPLACE_PATH_HERE', '/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7/step928646')
-        base_config_str = base_config_str.replace('REPLACE_LR_HERE', '0.000061499')
-    elif start_point == 'megamath5000':
-        base_config_str = base_config_str.replace('REPLACE_PATH_HERE', '/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-weka-anneal-from-928646-50B-megamath_v1.1.yaml/step5000/')
-        new_lr = '%.09f' % (0.000061499 * (1 - (5000 / 11931)))
-        base_config_str = base_config_str.replace('REPLACE_LR_HERE', new_lr)
-
-
+    if start_point == "preanneal":
+        base_config_str = base_config_str.replace(
+            "REPLACE_PATH_HERE", "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7/step928646"
+        )
+        base_config_str = base_config_str.replace("REPLACE_LR_HERE", "0.000061499")
+    elif start_point == "megamath5000":
+        base_config_str = base_config_str.replace(
+            "REPLACE_PATH_HERE",
+            "/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-weka-anneal-from-928646-50B-megamath_v1.1.yaml/step5000/",
+        )
+        new_lr = "%.09f" % (0.000061499 * (1 - (5000 / 11931)))
+        base_config_str = base_config_str.replace("REPLACE_LR_HERE", new_lr)
+
     lines_to_add = []
     for source in token_sources:
         lines_to_add.extend(get_token_strs(source))
-    true_lines_to_add = ['\n    %s' % line for line in lines_to_add]
-    output_str = base_config_str + ''.join(true_lines_to_add)
-    with open(output_yaml_file, 'w') as f:
+    true_lines_to_add = ["\n    %s" % line for line in lines_to_add]
+    output_str = base_config_str + "".join(true_lines_to_add)
+    with open(output_yaml_file, "w") as f:
         f.write(output_str)
@@ -466,11 +470,10 @@ def examine_config(yaml_file, bytes_per_token=4):
     print("Getting tokens per input file...")
     # Step 1: collect all paths of tokens
-    with open(yaml_file, 'r') as f:
+    with open(yaml_file, "r") as f:
         yaml_content = yaml.safe_load(f)
-    paths = yaml_content.get('data', {}).get('paths', [])
-    paths_to_tokens = {k: v // bytes_per_token for k,v in get_batch_s3_size(paths).items()}
-
+    paths = yaml_content.get("data", {}).get("paths", [])
+    paths_to_tokens = {k: v // bytes_per_token for k, v in get_batch_s3_size(paths).items()}
 
     # Step 2: Gather all sources, count tokens taken
     print("Grouping output files into groups...")
@@ -494,27 +497,28 @@ def get_group(s3_uri):
         total_tokens[g] = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     print("TOTAL_TOKENS", total_tokens)
     # Step 4: get ratios of percentage taken
-    ratios = {g: '%.04f' % (tokens_taken[g] / total_tokens[g])
-              for g in groups} # .04f here (ranging from 0.00 to 1.00)
+    ratios = {
+        g: "%.04f" % (tokens_taken[g] / total_tokens[g]) for g in groups
+    }  # .04f here (ranging from 0.00 to 1.00)
 
     # Step 5: actually print the outputs
     rows = sorted([(g, total_tokens[g], ratios[g], tokens_taken[g]) for g in groups])
     print("Put this in your spreadsheet!")
-    print(tabulate(rows, headers=['paths', 'total_tokens', 'percentage taken', 'tokens taken']))
+    print(tabulate(rows, headers=["paths", "total_tokens", "percentage taken", "tokens taken"]))
 
 
 def _read_path_comments(yaml_file):
     # This is helpful for examining paths
-    lines = open(yaml_file,'r').readlines()
+    lines = open(yaml_file, "r").readlines()
     path_sources = []
     seen_paths = False
     for line in lines:
-        if not seen_paths and line.strip() != 'paths:':
+        if not seen_paths and line.strip() != "paths:":
             continue
-        elif line.strip() == 'paths:':
+        elif line.strip() == "paths:":
             seen_paths = True
-        elif line.strip().startswith('#'):
-            path_sources.append(line.strip().split(' ')[1])
+        elif line.strip().startswith("#"):
+            path_sources.append(line.strip().split(" ")[1])
         else:
             pass
     return path_sources
@@ -525,7 +529,7 @@ def _read_path_comments(yaml_file):
 # =================================================
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     """
-    Use this interactively like `python -i peteish7_config_maker.py`, since tuples are weird to pass
+    Use this interactively like `python -i microanneal_config_maker.py`, since tuples are weird to pass
     [ or load all these modules into a jupyter notebook ]
@@ -539,4 +543,4 @@ def _read_path_comments(yaml_file):
 
     # and then you can populate the spreadsheet with the output of examine_config
     print(examine_config(OUTPUT_YAML))
-    """
\ No newline at end of file
+    """
diff --git a/scripts/peteish7_config_maker.py b/scripts/peteish7_config_maker.py
index 84fefe414..af7a9314c 100644
--- a/scripts/peteish7_config_maker.py
+++ b/scripts/peteish7_config_maker.py
@@ -1,37 +1,37 @@
-import boto3
+import glob
 import os
-from typing import List, Tuple
-from urllib.parse import urlparse
 import random
-import yaml
-import glob
-from tqdm.auto import tqdm
-from botocore.exceptions import ClientError
 from collections import defaultdict
-from tabulate import tabulate
 from concurrent.futures import ThreadPoolExecutor
+from typing import List, Optional, Tuple
+from urllib.parse import urlparse
-
+import boto3
+import yaml
+from botocore.exceptions import ClientError
+from tabulate import tabulate
+from tqdm.auto import tqdm
 
 # ===================================================
 # =                  S3 HELPERS                     =
 # ===================================================
 
-def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
+
+def get_single_s3_size(s3_uri: str, s3_client=None) -> Tuple[Optional[Exception], int]:
     # Gets the size in bytes of an individual s3 path
     if s3_client is None:
-        s3_client = boto3.client('s3')
-
+        s3_client = boto3.client("s3")
+
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
     # Remove leading slash and handle edge cases
-    object_key = parsed.path.lstrip('/')
+    object_key = parsed.path.lstrip("/")
     try:
-        s3_client = boto3.client('s3')
+        # reuse the client passed in (or created above) instead of re-creating one per call
         response = s3_client.head_object(Bucket=bucket_name, Key=object_key)
-        return (None, response['ContentLength'])
-    except Exception as e:
-        if e.response['Error']['Code'] == '404':
+        return (None, response["ContentLength"])
+    except ClientError as e:
+        if e.response["Error"]["Code"] == "404":
             raise FileNotFoundError(f"The object {object_key} does not exist in bucket {bucket_name}")
         else:
-            raise (e, 0)
+            return (e, 0)
@@ -39,8 +39,9 @@ def get_single_s3_size(s3_uri: str, s3_client=None) -> int:
 
 def get_batch_s3_size(s3_uris: List[str]):
     # Faster way to get size in bytes for a lot of s3 paths: maps s3_uri -> size
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
     errors = []
+
     def partial_size(s3_uri: str):
         output, size = get_single_s3_size(s3_uri, s3_client=s3_client)
-        if output != None:
+        if output is not None:
@@ -53,60 +54,58 @@ def partial_size(s3_uri: str):
         for future in tqdm(futures, total=len(futures)):
             results.append(future.result())
 
-    # Convert results to dictionary
     sizes = {}
     for s3_uri, size in results:
         sizes[s3_uri] = sizes.get(s3_uri, 0) + size
     # sizes = dict(results)
     return sizes
-
-
-def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
+
+
+def list_s3_paths(s3_uri: str, extension: str = ".npy") -> List[Tuple[str, int]]:
     """
     Lists all paths in an S3 bucket with given prefix and extension, along with their sizes.
-    
+
     Args:
-        bucket_name (str): Name of the S3 bucket
-        prefix (str): Prefix to filter objects (e.g., 'data/')
+        s3_uri (str): S3 URI giving the bucket and key prefix to list
+            (e.g., 's3://bucket/data/')
         extension (str): File extension to filter (e.g., '.csv')
-    
+
     Returns:
         List[Tuple[str, int]]: List of tuples containing (path, size in bytes)
     """
     parsed = urlparse(s3_uri)
     bucket_name = parsed.netloc
-    
+
     # Remove leading slash and handle edge cases
-    prefix = parsed.path.lstrip('/')
+    prefix = parsed.path.lstrip("/")
+
+    s3_client = boto3.client("s3")
-
-    s3_client = boto3.client('s3')
-
     # Ensure prefix ends with '/' if it's meant to be a directory
-    if prefix and not prefix.endswith('/'):
-        prefix += '/'
-
+    if prefix and not prefix.endswith("/"):
+        prefix += "/"
+
     # Ensure extension starts with '.'
-    if not extension.startswith('.'):
-        extension = '.' + extension
-
+    if not extension.startswith("."):
+        extension = "." + extension
+
     paths_and_sizes = []
-    paginator = s3_client.get_paginator('list_objects_v2')
-
+    paginator = s3_client.get_paginator("list_objects_v2")
+
     try:
         # Paginate through results to handle large buckets
         for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
-            if 'Contents' not in page:
+            if "Contents" not in page:
                 continue
-
-            for obj in page['Contents']:
-                key = obj['Key']
+
+            for obj in page["Contents"]:
+                key = obj["Key"]
                 if key.endswith(extension):
-                    paths_and_sizes.append((key, obj['Size']))
-
+                    paths_and_sizes.append((key, obj["Size"]))
+
         return paths_and_sizes
-
+
     except Exception as e:
         print(f"Error listing objects: {str(e)}")
         return []
@@ -116,8 +115,7 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
 # =               Other config-specific helpers                  =
 # =================================================================
 
-BASE_YAML_STR = \
-'''run_name: REPLACE_RUN_NAME_HERE
+BASE_YAML_STR = """run_name: REPLACE_RUN_NAME_HERE
 seed: 7201
 dry_run: false
 
@@ -364,17 +362,17 @@ def list_s3_paths(s3_uri: str, extension: str='.npy') -> List[Tuple[str, int]]:
       repetition_max_period: 13
       repetition_min_period: 1
       repetition_max_count: 32
-  paths:'''
+  paths:"""
 
 
 def human_format_number(num, decimal_places=2):
     """
     Format a number using K for thousands, M for millions, B for billions, T for trillions.
-    
+
     Args:
         num: Number to format
         decimal_places: Number of decimal places to show (default: 2)
-    
+
     Examples:
-        format_number(999) => '999'
-        format_number(1000) => '1.00K'
+        human_format_number(999) => '999'
+        human_format_number(1000) => '1.00K'
@@ -383,21 +381,21 @@ def human_format_number(num, decimal_places=2):
-        format_number(1500000000) => '1.50B'
+        human_format_number(1500000000) => '1.50B'
     """
     abs_num = abs(num)
-    sign = '-' if num < 0 else ''
-
+    sign = "-" if num < 0 else ""
+
     if abs_num < 1000:
         return f"{sign}{abs_num}"
-
-    suffixes = ['', 'K', 'M', 'B', 'T']
+
+    suffixes = ["", "K", "M", "B", "T"]
     magnitude = 0
-
+
     while abs_num >= 1000 and magnitude < len(suffixes) - 1:
         abs_num /= 1000
         magnitude += 1
-
+
     # Format with specified decimal places
     formatted = f"{abs_num:.{decimal_places}f}"
-
+
     return f"{sign}{formatted}{suffixes[magnitude]}"
@@ -410,22 +408,22 @@ def get_token_strs(token_source, bytes_per_token=4):
         paths_and_sizes = list_s3_paths(s3_source)
     parsed = urlparse(s3_source)
-    bucket_name = parsed.netloc 
-    paths_and_sizes = [('s3://%s/%s' % (bucket_name, p), s) for p,s in paths_and_sizes]
+    bucket_name = parsed.netloc
+    paths_and_sizes = [("s3://%s/%s" % (bucket_name, p), s) for p, s in paths_and_sizes]
     random.shuffle(paths_and_sizes)
     total_tokens = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     target_tokens = total_tokens * ratio
-
+
     paths_to_add = []
     tokens_to_add = 0
     for p, s in paths_and_sizes:
         paths_to_add.append(p)
         tokens_to_add += s // bytes_per_token
-        if tokens_to_add >= target_tokens: 
+        if tokens_to_add >= target_tokens:
             break
-    lines_to_add = ['#SOURCE: %s (%sT)' % (s3_source, human_format_number(tokens_to_add))]
+    lines_to_add = ["#SOURCE: %s (%sT)" % (s3_source, human_format_number(tokens_to_add))]
     for p in paths_to_add:
-        lines_to_add.append('- %s' % p)
+        lines_to_add.append("- %s" % p)
     return lines_to_add
@@ -434,17 +432,19 @@ def add_paths(token_sources, output_yaml_file):
     # Token sources is a list of either... s3_uri: str | (s3_uri: str, fraction: float)
     # Also I'm not bothering with pyyaml, just appending to the base config (which will be included)
     # ^this is a very crude stone-age tool, don't @ me
-    assert output_yaml_file.startswith('peteish7-weka-anneal-from-928646-50B-')
-    assert output_yaml_file.endswith('.yaml')
+    assert output_yaml_file.startswith("peteish7-weka-anneal-from-928646-50B-")
+    assert output_yaml_file.endswith(".yaml")
+
+    base_config_str = BASE_YAML_STR.replace(
+        "REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0]
+    )
-    base_config_str = BASE_YAML_STR.replace("REPLACE_RUN_NAME_HERE", os.path.splitext(os.path.basename(output_yaml_file))[0])
-
     lines_to_add = []
     for source in token_sources:
         lines_to_add.extend(get_token_strs(source))
-    true_lines_to_add = ['\n    %s' % line for line in lines_to_add]
-    output_str = base_config_str + ''.join(true_lines_to_add)
-    with open(output_yaml_file, 'w') as f:
+    true_lines_to_add = ["\n    %s" % line for line in lines_to_add]
+    output_str = base_config_str + "".join(true_lines_to_add)
+    with open(output_yaml_file, "w") as f:
         f.write(output_str)
@@ -457,11 +457,10 @@ def examine_config(yaml_file, bytes_per_token=4):
     print("Getting tokens per input file...")
     # Step 1: collect all paths of tokens
-    with open(yaml_file, 'r') as f:
+    with open(yaml_file, "r") as f:
         yaml_content = yaml.safe_load(f)
-    paths = yaml_content.get('data', {}).get('paths', [])
-    paths_to_tokens = {k: v // bytes_per_token for k,v in get_batch_s3_size(paths).items()}
-
+    paths = yaml_content.get("data", {}).get("paths", [])
+    paths_to_tokens = {k: v // bytes_per_token for k, v in get_batch_s3_size(paths).items()}
 
     # Step 2: Gather all sources, count tokens taken
     print("Grouping output files into groups...")
@@ -484,27 +483,28 @@ def get_group(s3_uri):
         paths_and_sizes = list_s3_paths(g)
         total_tokens[g] = sum(_[1] for _ in paths_and_sizes) // bytes_per_token
     # Step 4: get ratios of percentage taken
-    ratios = {g: '%.04f' % (tokens_taken[g] / total_tokens[g])
-              for g in groups} # .04f here (ranging from 0.00 to 1.00)
+    ratios = {
+        g: "%.04f" % (tokens_taken[g] / total_tokens[g]) for g in groups
+    }  # .04f here (ranging from 0.00 to 1.00)
 
     # Step 5: actually print the outputs
     rows = sorted([(g, total_tokens[g], ratios[g], tokens_taken[g]) for g in groups])
     print("Put this in your spreadsheet!")
-    print(tabulate(rows, headers=['paths', 'total_tokens', 'percentage taken', 'tokens taken']))
+    print(tabulate(rows, headers=["paths", "total_tokens", "percentage taken", "tokens taken"]))
 
 
 def _read_path_comments(yaml_file):
     # This is helpful for examining paths
-    lines = open(yaml_file,'r').readlines()
+    lines = open(yaml_file, "r").readlines()
     path_sources = []
     seen_paths = False
     for line in lines:
-        if not seen_paths and line.strip() != 'paths:':
+        if not seen_paths and line.strip() != "paths:":
             continue
-        elif line.strip() == 'paths:':
+        elif line.strip() == "paths:":
             seen_paths = True
-        elif line.strip().startswith('#'):
-            path_sources.append(line.strip().split(' ')[1])
+        elif line.strip().startswith("#"):
+            path_sources.append(line.strip().split(" ")[1])
         else:
             pass
     return path_sources
@@ -515,7 +515,7 @@ def _read_path_comments(yaml_file):
 # =================================================
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     """
     Use this interactively like `python -i peteish7_config_maker.py`, since tuples are weird to pass
     [ or load all these modules into a jupyter notebook ]
@@ -529,4 +529,4 @@ def _read_path_comments(yaml_file):
 
     # and then you can populate the spreadsheet with the output of examine_config
     print(examine_config(OUTPUT_YAML))
-    """
\ No newline at end of file
+    """
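
For reference, a minimal sketch of how these config makers are meant to be driven interactively, pieced together from the `__main__` docstrings and the `add_paths`/`examine_config` signatures above. The S3 URIs, the 0.25 fraction, and the output filename below are hypothetical placeholders, not real buckets or runs.

# Run as: python -i scripts/peteish7_config_maker.py
# A token source is either a bare s3_uri (take all of its tokens) or an
# (s3_uri, fraction) tuple (take roughly that fraction, file by file).
TOKEN_SOURCES = [
    "s3://example-bucket/tokens/source-a/",          # hypothetical source, take 100%
    ("s3://example-bucket/tokens/source-b/", 0.25),  # hypothetical source, take ~25%
]
# add_paths asserts this run-name prefix and the .yaml suffix
OUTPUT_YAML = "peteish7-weka-anneal-from-928646-50B-example-mix.yaml"

add_paths(TOKEN_SOURCES, OUTPUT_YAML)  # writes BASE_YAML_STR plus the sampled '- s3://...' path lines
examine_config(OUTPUT_YAML)  # re-reads the yaml and prints the per-source token table for the spreadsheet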