Skip to content

Commit

Permalink
updated code after Dirk's review
Browse files Browse the repository at this point in the history
  • Loading branch information
aman-17 committed Nov 20, 2024
1 parent a622fb0 commit 8aac2ea
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 62 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ pyrightconfig.json
doc/_build/
*.swp
.DS_Store
readme_misc.md

# python

Expand Down
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ Example: To download checkpoint at step 2000:
```bash
python scripts/download_checkpoints.py checkpoints/official/OLMo-1B.csv --save-dir ./checkpoints/ --step 2000
```
**Note**: All checkpoints in `checkpoints/official/` are unsharded files.
**Note**: All checkpoints in `checkpoints/official/` are unsharded.

2. Resume training using the downloaded checkpoint. You can specify either a local path or a URL with the `--load_path` argument. For example, to resume training from step 2000 of the OLMo 1B run:

```bash
torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=checkpoints/step2000 --save_folder=./new_checkpoints --run_name=olmo_test --save_overwrite
torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path=checkpoints/step2000 --save_folder=./new_checkpoints --run_name=olmo_test
```
The command above:
- Loads the checkpoint from `checkpoints/step2000`
Expand All @@ -133,9 +133,6 @@ The command above:
- Overwrites existing checkpoints in the save folder (only when the `--save_overwrite` flag is passed).

### Inspecting training data

To inspect the exact tokens used in training batches for OLMo models, first download the training data. If you don't have an R2 API key, use the public HTTP URLs and update your configuration file with the local data paths. After completing this setup, you can use the inspection tools to examine the training batches.

Find the data order file URL in the [Models Overview](#models-overview) table. For example, the OLMo-7B model's first epoch data order file is located at [https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy](https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy).
Once you have that URL, you can use this snippet to inspect the data within a particular batch:

Expand Down
89 changes: 40 additions & 49 deletions scripts/download_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,55 +41,58 @@ def try_get_directory_listing(url):
"model.safetensors",
"optim.safetensors",
]

found_files = []
for pattern in common_files:
test_url = urljoin(url.rstrip('/') + '/', pattern)
try:
test_url = urljoin(url.rstrip('/') + '/', pattern)
response = requests.head(test_url)
# response.raise_for_status()
if response.status_code == 200:
found_files.append(pattern)
except requests.exceptions.RequestException:
continue

except requests.exceptions.HTTPError as e:
print(f"HTTP error for {pattern}: {e}")
except requests.exceptions.RequestException as e:
print(f"Connection error for {pattern}: {e}")
return found_files

def download_checkpoint(url, save_dir):
"""Download all files from a checkpoint directory."""
r2_url = convert_to_r2_url(url)
public_url = convert_to_public_url(r2_url)

base_path = Path(save_dir)
base_path.mkdir(parents=True, exist_ok=True)

print(f"\nR2 URL: {r2_url}")
print(f"Public URL: {public_url}")
print(f"Saving to: {base_path}")

print("Checking for available files...")
available_files = try_get_directory_listing(public_url)

if not available_files:
print("No files found using common patterns. The directory might be empty or use different file patterns.")
return

for file in available_files:
file_url = urljoin(public_url.rstrip('/') + '/', file)
file_path = base_path / file

try:
print(f"\nDownloading: {file}")
download_file(file_url, file_path)
except requests.exceptions.RequestException as e:
print(f"Error downloading {file}: {e}")
continue
"""Download all files from a checkpoint directory."""
r2_url = convert_to_r2_url(url)
public_url = convert_to_public_url(r2_url)
base_path = Path(save_dir)
base_path.mkdir(parents=True, exist_ok=True)
print(f"Saving to: {base_path}")
available_files = try_get_directory_listing(public_url)

if not available_files:
raise ValueError("No matching files found in directory")

failed_files = []
for file in available_files:
file_url = urljoin(public_url.rstrip('/') + '/', file)
file_path = base_path / file
try:
print(f"\nDownloading: {file}")
download_file(file_url, file_path)
except requests.exceptions.Timeout:
print(f"Timeout error for {file}, retrying once...")
try:
download_file(file_url, file_path)
except requests.exceptions.RequestException as e:
failed_files.append(file)
print(f"Failed to download {file}: {e}")
except requests.exceptions.RequestException as e:
failed_files.append(file)
print(f"Failed to download {file}: {e}")
if failed_files:
print(f"\nWARNING: Failed to download these files: {failed_files}")

def main():
parser = argparse.ArgumentParser(description='Download OLMo checkpoints from CSV')
parser.add_argument('csv_file', type=str, help='Path to the CSV file containing checkpoint URLs')
parser.add_argument('--save-dir', type=str, default='./checkpoints',
help='Base directory to save downloaded checkpoints')
parser.add_argument('--step', type=str, help='Specific step number to download (optional)')
parser.add_argument('--step', type=str, default='1000', help='Specific step number to download.')
parser.add_argument('--list-steps', action='store_true', help='List available step numbers and exit')

args = parser.parse_args()
Expand All @@ -101,7 +104,7 @@ def main():
urls = [(row['Step'], row['Checkpoint Directory']) for row in reader]

if args.list_steps:
print("\nAvailable steps:")
print("Available steps:")
for step, _ in urls:
print(f"Step {step}")
return
Expand All @@ -114,26 +117,14 @@ def main():
return

print(f"Saving checkpoints to: {args.save_dir}")
print("\nURL conversions:")
for step, url in urls:
r2_url = convert_to_r2_url(url)
public_url = convert_to_public_url(r2_url)
print(f"\nStep {step}:")
print(f"Original URL: {url}")
print(f"R2 URL: {r2_url}")
print(f"Public URL: {public_url}")

proceed = input("\nDo you want to proceed with the download? (y/n): ")
if proceed.lower() != 'y':
print("Download cancelled.")
return

for step, url in urls:
save_path = os.path.join(args.save_dir, f"step{step}")
try:
download_checkpoint(url, save_path)
except Exception as e:
print(f"Error during download of step {step}: {e}")
download_checkpoint(url, save_path)


if __name__ == "__main__":
main()
11 changes: 4 additions & 7 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
)
cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
elif cfg.distributed_strategy == DistributedStrategy.fsdp:
# checkpoint_type = (
# CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
# )
checkpoint_type = CheckpointType.unsharded
checkpoint_type = (
CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
)
else:
raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")

Expand All @@ -298,9 +297,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
cfg.load_path,
load_optimizer_state=not cfg.reset_optimizer_state,
load_trainer_state=not cfg.reset_trainer_state,
# sharded_checkpointer=cfg.load_path_sharded_checkpointer,
sharded_checkpointer= False,
checkpoint_type=CheckpointType.unsharded
sharded_checkpointer=cfg.load_path_sharded_checkpointer,
)
log.info("Checkpoint successfully loaded")

Expand Down

0 comments on commit 8aac2ea

Please sign in to comment.