diff --git a/.gitignore b/.gitignore index c714b3e6f..bce8ecbc3 100644 --- a/.gitignore +++ b/.gitignore @@ -31,7 +31,6 @@ pyrightconfig.json doc/_build/ *.swp .DS_Store -readme_misc.md # python diff --git a/README.md b/README.md index f08f911d6..398bfef4f 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,12 @@ Example: To download checkpoint at step 2000: ```bash python scripts/download_checkpoints.py checkpoints/official/OLMo-1B.csv --save-dir ./checkpoints/ --step 2000 ``` -**Note**: All checkpoints in `checkpoints/official/` are unsharded files. +**Note**: All checkpoints in `checkpoints/official/` are unsharded. 2. Resume training using the downloaded checkpoint. You can specify either a local path or URL using the --load_path argument: For example, to resume training from step 2000 of the OLMo 1B run: ```bash -torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=checkpoints/step2000 --save_folder=./new_checkpoints --run_name=olmo_test --save_overwrite +torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=checkpoints/step2000 --save_folder=./new_checkpoints --run_name=olmo_test ``` The command above: - Loads the checkpoint from `checkpoints/step2000` @@ -133,9 +133,6 @@ The command above: - Overwrites existing checkpoints in the save folder. ### Inspecting training data - -To inspect the exact tokens used in training batches for OLMo models, first download the training data. If you don't have an R2 API key, use the public HTTP URLs and update your configuration file with the local data paths. After completing this setup, you can use the inspection tools to examine the training batches. - Find the data order file URL in the [Models Overview](#models-overview) table. For example, the OLMo-7B model's first epoch data order file is located at [https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy). Once you have that you can use this snippet to inspect the data within a particular batch: diff --git a/scripts/download_checkpoints.py b/scripts/download_checkpoints.py index c3237020d..4fa1b6183 100644 --- a/scripts/download_checkpoints.py +++ b/scripts/download_checkpoints.py @@ -41,55 +41,58 @@ def try_get_directory_listing(url): "model.safetensors", "optim.safetensors", ] - found_files = [] for pattern in common_files: - test_url = urljoin(url.rstrip('/') + '/', pattern) try: + test_url = urljoin(url.rstrip('/') + '/', pattern) response = requests.head(test_url) + # response.raise_for_status() if response.status_code == 200: found_files.append(pattern) - except requests.exceptions.RequestException: - continue - + except requests.exceptions.HTTPError as e: + print(f"HTTP error for {pattern}: {e}") + except requests.exceptions.RequestException as e: + print(f"Connection error for {pattern}: {e}") return found_files def download_checkpoint(url, save_dir): - """Download all files from a checkpoint directory.""" - r2_url = convert_to_r2_url(url) - public_url = convert_to_public_url(r2_url) - - base_path = Path(save_dir) - base_path.mkdir(parents=True, exist_ok=True) - - print(f"\nR2 URL: {r2_url}") - print(f"Public URL: {public_url}") - print(f"Saving to: {base_path}") - - print("Checking for available files...") - available_files = try_get_directory_listing(public_url) - - if not available_files: - print("No files found using common patterns. The directory might be empty or use different file patterns.") - return - - for file in available_files: - file_url = urljoin(public_url.rstrip('/') + '/', file) - file_path = base_path / file - - try: - print(f"\nDownloading: {file}") - download_file(file_url, file_path) - except requests.exceptions.RequestException as e: - print(f"Error downloading {file}: {e}") - continue + """Download all files from a checkpoint directory.""" + r2_url = convert_to_r2_url(url) + public_url = convert_to_public_url(r2_url) + base_path = Path(save_dir) + base_path.mkdir(parents=True, exist_ok=True) + print(f"Saving to: {base_path}") + available_files = try_get_directory_listing(public_url) + + if not available_files: + raise ValueError("No matching files found in directory") + + failed_files = [] + for file in available_files: + file_url = urljoin(public_url.rstrip('/') + '/', file) + file_path = base_path / file + try: + print(f"\nDownloading: {file}") + download_file(file_url, file_path) + except requests.exceptions.Timeout: + print(f"Timeout error for {file}, retrying once...") + try: + download_file(file_url, file_path) + except requests.exceptions.RequestException as e: + failed_files.append(file) + print(f"Failed to download {file}: {e}") + except requests.exceptions.RequestException as e: + failed_files.append(file) + print(f"Failed to download {file}: {e}") + if failed_files: + print(f"\nWARNING: Failed to download these files: {failed_files}") def main(): parser = argparse.ArgumentParser(description='Download OLMo checkpoints from CSV') parser.add_argument('csv_file', type=str, help='Path to the CSV file containing checkpoint URLs') parser.add_argument('--save-dir', type=str, default='./checkpoints', help='Base directory to save downloaded checkpoints') - parser.add_argument('--step', type=str, help='Specific step number to download (optional)') + parser.add_argument('--step', type=str, default='1000', help='Specific step number to download.') parser.add_argument('--list-steps', action='store_true', help='List available step numbers and exit') args = parser.parse_args() @@ -101,7 +104,7 @@ def main(): urls = [(row['Step'], row['Checkpoint Directory']) for row in reader] if args.list_steps: - print("\nAvailable steps:") + print("Available steps:") for step, _ in urls: print(f"Step {step}") return @@ -114,26 +117,14 @@ def main(): return print(f"Saving checkpoints to: {args.save_dir}") - print("\nURL conversions:") for step, url in urls: r2_url = convert_to_r2_url(url) public_url = convert_to_public_url(r2_url) print(f"\nStep {step}:") - print(f"Original URL: {url}") - print(f"R2 URL: {r2_url}") print(f"Public URL: {public_url}") - - proceed = input("\nDo you want to proceed with the download? (y/n): ") - if proceed.lower() != 'y': - print("Download cancelled.") - return - - for step, url in urls: save_path = os.path.join(args.save_dir, f"step{step}") - try: - download_checkpoint(url, save_path) - except Exception as e: - print(f"Error during download of step {step}: {e}") + download_checkpoint(url, save_path) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 1baffc973..ff7bb31b8 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -268,10 +268,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None: ) cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep elif cfg.distributed_strategy == DistributedStrategy.fsdp: - # checkpoint_type = ( - # CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded - # ) - checkpoint_type = CheckpointType.unsharded + checkpoint_type = ( + CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded + ) else: raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") @@ -298,9 +297,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: cfg.load_path, load_optimizer_state=not cfg.reset_optimizer_state, load_trainer_state=not cfg.reset_trainer_state, - # sharded_checkpointer=cfg.load_path_sharded_checkpointer, - sharded_checkpointer= False, - checkpoint_type=CheckpointType.unsharded + sharded_checkpointer=cfg.load_path_sharded_checkpointer, ) log.info("Checkpoint successfully loaded")