-
Notifications
You must be signed in to change notification settings - Fork 514
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Documentation Improvements #745
Changes from 3 commits
a622fb0
8aac2ea
c21087d
4e256a9
71abc2c
c904429
36ba37a
2448127
ccfb06d
930daaa
e1e54d9
b2f7ffc
e4786af
46cfcce
5d2fbb7
796de60
206da7c
889aaaa
d867ced
973b34d
b3324b5
dc3cfe1
8c34f59
4fdc829
5da6e3d
a40d46e
3b0139d
d520823
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,7 +32,6 @@ doc/_build/ | |
*.swp | ||
.DS_Store | ||
|
||
|
||
# python | ||
|
||
*.pyc | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,130 @@ | ||||||
import csv | ||||||
import os | ||||||
import requests | ||||||
from tqdm import tqdm | ||||||
import argparse | ||||||
from pathlib import Path | ||||||
from urllib.parse import urljoin | ||||||
|
||||||
def convert_to_r2_url(http_url): | ||||||
"""Convert HTTP URL to R2 URL format.""" | ||||||
if http_url.startswith('https://olmo-checkpoints.org/'): | ||||||
return http_url.replace('https://olmo-checkpoints.org/', 'r2://olmo-checkpoints/') | ||||||
return http_url | ||||||
|
||||||
def convert_to_public_url(r2_url): | ||||||
"""Convert R2 URL to public HTTP URL format.""" | ||||||
if r2_url.startswith('r2://olmo-checkpoints/'): | ||||||
return r2_url.replace('r2://olmo-checkpoints/', 'https://olmo-checkpoints.org/') | ||||||
return r2_url | ||||||
|
||||||
def download_file(url, save_path, chunk_size=8192): | ||||||
"""Download a file with progress bar.""" | ||||||
response = requests.get(url, stream=True) | ||||||
response.raise_for_status() | ||||||
total_size = int(response.headers.get('content-length', 0)) | ||||||
save_path.parent.mkdir(parents=True, exist_ok=True) | ||||||
|
||||||
with open(save_path, 'wb') as f: | ||||||
with tqdm(total=total_size, unit='B', unit_scale=True, desc=save_path.name) as pbar: | ||||||
for chunk in response.iter_content(chunk_size=chunk_size): | ||||||
if chunk: | ||||||
f.write(chunk) | ||||||
pbar.update(len(chunk)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Retries should be handled in here. If the request fails with a recoverable error code (anything that starts with 5XX, 408, 409, 429), it should wait one second, and then try again. Try up to 5 times, and then give up. When giving up, it must make sure that the file at |
||||||
|
||||||
def try_get_directory_listing(url): | ||||||
common_files = [ | ||||||
"config.yaml", | ||||||
"model.pt", | ||||||
"optim.pt", | ||||||
"train.pt", | ||||||
"model.safetensors", | ||||||
"optim.safetensors", | ||||||
] | ||||||
found_files = [] | ||||||
for pattern in common_files: | ||||||
try: | ||||||
test_url = urljoin(url.rstrip('/') + '/', pattern) | ||||||
response = requests.head(test_url) | ||||||
# response.raise_for_status() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leftover debugging code? |
||||||
if response.status_code == 200: | ||||||
found_files.append(pattern) | ||||||
except requests.exceptions.HTTPError as e: | ||||||
print(f"HTTP error for {pattern}: {e}") | ||||||
except requests.exceptions.RequestException as e: | ||||||
print(f"Connection error for {pattern}: {e}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is still swallowing exceptions. It just prints about them. This part of the code should check whether the exception is a 404 error, in which case it's fine, and otherwise let the exception propagate. |
||||||
return found_files | ||||||
|
||||||
def download_checkpoint(url, save_dir): | ||||||
"""Download all files from a checkpoint directory.""" | ||||||
r2_url = convert_to_r2_url(url) | ||||||
public_url = convert_to_public_url(r2_url) | ||||||
base_path = Path(save_dir) | ||||||
base_path.mkdir(parents=True, exist_ok=True) | ||||||
print(f"Saving to: {base_path}") | ||||||
available_files = try_get_directory_listing(public_url) | ||||||
|
||||||
if not available_files: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this make any better? found_files can never be negative right? |
||||||
raise ValueError("No matching files found in directory") | ||||||
|
||||||
failed_files = [] | ||||||
for file in available_files: | ||||||
file_url = urljoin(public_url.rstrip('/') + '/', file) | ||||||
file_path = base_path / file | ||||||
try: | ||||||
print(f"\nDownloading: {file}") | ||||||
download_file(file_url, file_path) | ||||||
except requests.exceptions.Timeout: | ||||||
print(f"Timeout error for {file}, retrying once...") | ||||||
try: | ||||||
download_file(file_url, file_path) | ||||||
except requests.exceptions.RequestException as e: | ||||||
failed_files.append(file) | ||||||
print(f"Failed to download {file}: {e}") | ||||||
except requests.exceptions.RequestException as e: | ||||||
failed_files.append(file) | ||||||
print(f"Failed to download {file}: {e}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't retry in this way. The Python requests library already supports retries. In this code we should just call |
||||||
if failed_files: | ||||||
print(f"\nWARNING: Failed to download these files: {failed_files}") | ||||||
|
||||||
def main(): | ||||||
parser = argparse.ArgumentParser(description='Download OLMo checkpoints from CSV') | ||||||
parser.add_argument('csv_file', type=str, help='Path to the CSV file containing checkpoint URLs') | ||||||
parser.add_argument('--save-dir', type=str, default='./checkpoints', | ||||||
help='Base directory to save downloaded checkpoints') | ||||||
parser.add_argument('--step', type=str, default='1000', help='Specific step number to download.') | ||||||
parser.add_argument('--list-steps', action='store_true', help='List available step numbers and exit') | ||||||
aman-17 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
args = parser.parse_args() | ||||||
|
||||||
print(f"Reading CSV file: {args.csv_file}") | ||||||
|
||||||
with open(args.csv_file, 'r') as f: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That looks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, without csv we can't download the checkpoints since it contains links. |
||||||
reader = csv.DictReader(f) | ||||||
urls = [(row['Step'], row['Checkpoint Directory']) for row in reader] | ||||||
|
||||||
if args.list_steps: | ||||||
print("Available steps:") | ||||||
for step, _ in urls: | ||||||
print(f"Step {step}") | ||||||
return | ||||||
|
||||||
if args.step: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
urls = [(step, url) for step, url in urls if step == args.step] | ||||||
if not urls: | ||||||
print(f"Error: Step {args.step} not found in the CSV file.") | ||||||
print("Use --list-steps to see available step numbers.") | ||||||
return | ||||||
|
||||||
print(f"Saving checkpoints to: {args.save_dir}") | ||||||
for step, url in urls: | ||||||
r2_url = convert_to_r2_url(url) | ||||||
public_url = convert_to_public_url(r2_url) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The URLs in the CSV are already public URLs? Why are we doing this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've taken these functions from my another file where I will have r2 urls in csv while uploading. I missed this here. |
||||||
print(f"\nStep {step}:") | ||||||
print(f"Public URL: {public_url}") | ||||||
save_path = os.path.join(args.save_dir, f"step{step}") | ||||||
download_checkpoint(url, save_path) | ||||||
|
||||||
|
||||||
if __name__ == "__main__": | ||||||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
train.safetensors
? Also, for the original model we just have*.pt
so we should have that format mentioned somewhere.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are going to save in .safetensors starting from OLMo-2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, but people might still try to use older OLMo models. The documentation should be backwards-compatible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, not backwards compatible