MPS support #789

Merged
merged 25 commits into from
Feb 11, 2025
Changes from 1 commit
69aebc8
working on mps_support
aman-17 Jan 16, 2025
60bbe0f
update
aman-17 Jan 17, 2025
264b73a
add mps support in main
aman-17 Jan 17, 2025
94f594c
cleaned code
aman-17 Jan 18, 2025
e961d7a
small fix
aman-17 Jan 19, 2025
bbc6c78
added readme and more cleaning
aman-17 Jan 21, 2025
3b340ac
updated main function, readme and more cleaning
aman-17 Jan 21, 2025
03d08a0
Merge branch 'main' into amanr/mps_support
aman-17 Jan 22, 2025
9dfa9d6
fixed style and type_check for PR
aman-17 Jan 22, 2025
0c2e341
fixed iblack sort
aman-17 Jan 22, 2025
f793ed9
updated changelog
aman-17 Jan 22, 2025
9a75350
fixed pytest error and updated changelog
aman-17 Jan 29, 2025
b18d2bd
removed precision, batch_size from main function
aman-17 Jan 29, 2025
1c07659
Merge branch 'main' into amanr/mps_support
dirkgr Feb 4, 2025
53f23c8
Resolved all the comments mentioned by Dirk
aman-17 Feb 5, 2025
cd8fbd4
fixed style
aman-17 Feb 5, 2025
b0de86e
fixed changelog conflict
aman-17 Feb 5, 2025
168fabe
Merge branch 'main' into amanr/mps_support
aman-17 Feb 5, 2025
0e016e8
updated changelog for MPS support
aman-17 Feb 5, 2025
cc56dd2
resolved Dirk's comments
aman-17 Feb 5, 2025
bc8ee29
updated log.info to log.warning for RNG warning
aman-17 Feb 5, 2025
cb30dc7
added `get_autocast_dtype("mps")` for attention bias
aman-17 Feb 5, 2025
4259de2
passed `weights_only=False` while loading checkpoints
aman-17 Feb 5, 2025
e192068
removed `.to(device)` for checkpoint loading
aman-17 Feb 6, 2025
dd26d23
updated `_http_get_bytes_range` to resolve Sudden data error during t…
aman-17 Feb 9, 2025
add mps support in main
aman-17 committed Jan 17, 2025
commit 264b73ab23bacc04d4c6b898fb7f6035e64e875e
6 changes: 4 additions & 2 deletions scripts/train.py
@@ -402,7 +402,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
dist.init_process_group(
backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string)
)
else:
elif torch.backends.mps.is_available():
if not os.getenv("RANK"):
os.environ["RANK"] = "0"
if not os.getenv("WORLD_SIZE"):
@@ -411,9 +411,11 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
os.environ["MASTER_ADDR"] = "0.0.0.0"
if not os.getenv("MASTER_PORT"):
os.environ["MASTER_PORT"] = "24500"

dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))

else:
dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))

log.info("Process group initialized")

prepare_cli_environment()
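
The branch logic in this hunk can be sketched without torch as two small helpers. This is only an illustration of the pattern, not the code in scripts/train.py: the names `choose_backend`, `fill_single_process_env`, and `SINGLE_PROCESS_DEFAULTS` are hypothetical, the CUDA check is assumed (the condition of the first branch lies above the hunk), and the `WORLD_SIZE` default of "1" is assumed because that line falls between the two hunks shown.

```python
from typing import Dict, MutableMapping

# Defaults copied from the diff where visible. WORLD_SIZE="1" is an assumption:
# the line that sets it is not part of the hunks shown.
SINGLE_PROCESS_DEFAULTS: Dict[str, str] = {
    "RANK": "0",
    "WORLD_SIZE": "1",  # assumed, not visible in the diff
    "MASTER_ADDR": "0.0.0.0",
    "MASTER_PORT": "24500",
}


def choose_backend(cuda_available: bool) -> str:
    """Return "nccl" when CUDA is available, otherwise "gloo".

    In the diff, both the MPS branch and the final fallback branch use gloo,
    so MPS availability does not change the backend, only the env setup.
    """
    return "nccl" if cuda_available else "gloo"


def fill_single_process_env(env: MutableMapping[str, str]) -> MutableMapping[str, str]:
    """Fill in rendezvous variables that are unset or empty.

    Mirrors the `if not os.getenv(NAME): os.environ[NAME] = ...` pattern:
    values that are already set are left untouched.
    """
    for key, value in SINGLE_PROCESS_DEFAULTS.items():
        if not env.get(key):
            env[key] = value
    return env
```

In a real script the mapping passed in would presumably be `os.environ`, called before `dist.init_process_group(backend="gloo", ...)`. Taking the mapping as a parameter keeps the helper testable without touching the process environment.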