Adding support for Pathways proxy #690

Open: wants to merge 13 commits into base: main (showing changes from 3 commits)

4 changes: 3 additions & 1 deletion axlearn/common/launch.py
@@ -113,7 +113,9 @@ def setup():
 logging.info("Devices: %s", devices)
 local_devices = jax.local_devices()
 logging.info("Local Devices: %s", local_devices)
-if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
+if FLAGS.jax_backend != "proxy" and (
+not devices or not all(device.platform == FLAGS.jax_backend for device in devices)
+):
 raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
 if FLAGS.data_dir:
 # TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
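
The relaxed check in this hunk can be sketched in isolation as follows. This is a minimal sketch, not the code in `launch.py`: `FakeDevice` and `check_backend` are hypothetical names introduced for illustration, and the condition is copied from the added lines of the diff.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a JAX device; only `platform` matters here."""

    platform: str


def check_backend(devices: Sequence[FakeDevice], jax_backend: str) -> None:
    """Raises RuntimeError on a platform mismatch, unless the backend is "proxy"."""
    # Same condition as the added lines: the platform check is skipped
    # entirely when the requested backend is "proxy".
    if jax_backend != "proxy" and (
        not devices or not all(d.platform == jax_backend for d in devices)
    ):
        raise RuntimeError(f"Expected backend {jax_backend}. Got {devices}.")
```

One consequence worth noting: with `jax_backend="proxy"` the check also passes when `devices` is empty, because the whole condition short-circuits.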
1 change: 1 addition & 0 deletions axlearn/common/launch_trainer_main.py
@@ -2,6 +2,7 @@

 """Main function for launching the trainer."""

+import pathwaysutils  # pylint: disable=unused-import
jesus-orozco marked this conversation as resolved.
 from absl import app, flags

 from axlearn.common import launch, launch_trainer, measurement
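
The import above is unconditional, while the `pyproject.toml` change in this PR appears to declare `pathwaysutils` in its own dependency group. If that group is optional, one possible guard is sketched below. `import_if_available` is a hypothetical helper, not part of AXLearn or pathwaysutils, and it is only intended for top-level module names.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional


def import_if_available(name: str) -> Optional[ModuleType]:
    """Imports top-level module `name` if it is installed; returns None otherwise."""
    # find_spec returns None when a top-level module cannot be located.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)
```

Under that assumption, `import_if_available("pathwaysutils")` would stand in for the bare import, so environments without the package would still be able to launch the trainer.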
7 changes: 4 additions & 3 deletions axlearn/common/utils_spmd.py
@@ -53,7 +53,7 @@ def setup(
 if initialization_timeout is not None:
 init_kwargs["initialization_timeout"] = initialization_timeout

-if jax_backend == "tpu":
+if jax_backend in ("tpu", "proxy"):
jesus-orozco marked this conversation as resolved.
 if not (
 distributed_coordinator is None and num_processes is None and process_id is None
 ):
@@ -115,5 +115,6 @@ def setup(
 f"({initialization_timeout} seconds)."
 )
 else:
-jax.distributed.initialize(**init_kwargs)
-_jax_distributed_initialized = True
+if jax_backend != "proxy":
+jax.distributed.initialize(**init_kwargs)
Contributor:

So using Pathways does not require jax.distributed.initialize? Please add a comment.

Contributor Author:

Rebased and added a comment for Pathways.

+_jax_distributed_initialized = True
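
The reviewer's question can be made concrete with a small sketch of the control flow in this hunk. It assumes, as the diff implies but does not state, that a Pathways proxy job does not need `jax.distributed.initialize`. `maybe_initialize` and the injected `initialize_fn` are hypothetical names, used so the sketch runs without a cluster.

```python
from typing import Any, Callable


def maybe_initialize(
    jax_backend: str,
    initialize_fn: Callable[..., None],
    **init_kwargs: Any,
) -> bool:
    """Calls `initialize_fn(**init_kwargs)` unless the backend is "proxy".

    Returns True if `initialize_fn` was called, False if it was skipped.
    """
    if jax_backend == "proxy":
        # Assumption: with the Pathways proxy, multi-host coordination is handled
        # outside this process, so explicit distributed initialization is skipped.
        return False
    initialize_fn(**init_kwargs)
    return True
```

In real code `initialize_fn` would be `jax.distributed.initialize`. Passing it in as an argument keeps the branch testable on a single machine.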
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -143,6 +143,11 @@ orbax = [
 "orbax-checkpoint==0.5.23",
 ]

jesus-orozco marked this conversation as resolved.
+# Pathways utilities.
+pathways = [
+"pathwaysutils@git+https://github.com/google/[email protected]",  # for JAX+Pathways single-controller accelerator coordinator
+]
+
 [tool.flit.module]
 # This defines the import name. https://flit.pypa.io/en/stable/pyproject_toml.html#module-section
 name = "axlearn"