Add GKE A3 Ultra support #940

Closed · wants to merge 21 commits

3 changes: 3 additions & 0 deletions Dockerfile
@@ -99,6 +99,9 @@ COPY . .

FROM base AS gpu

# Needed for NVIDIA CX7 based RDMA (not cloud specific).
RUN apt-get update && apt-get install -y ibverbs-utils
Contributor: Nice!


# TODO(markblee): Support extras.
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install .[core,gpu]
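
For reference, ibverbs-utils ships user-space RDMA tooling such as ibv_devinfo. Below is a minimal sketch (a hypothetical helper, not part of this PR) of a container-side check that RDMA devices are actually visible once the package is installed:

    # Hypothetical sanity check, not part of the PR: confirm that the ibverbs
    # tooling installed above can enumerate RDMA devices (e.g. the CX7 NICs on
    # A3 Ultra hosts) from inside the container.
    import shutil
    import subprocess


    def rdma_devices_visible() -> bool:
        """Returns True if ibv_devinfo runs and reports at least one HCA."""
        if shutil.which("ibv_devinfo") is None:
            return False  # ibverbs-utils is not installed.
        proc = subprocess.run(
            ["ibv_devinfo", "-l"], capture_output=True, text=True, check=False
        )
        # `ibv_devinfo -l` prints "<N> HCAs found:" followed by device names.
        return proc.returncode == 0 and "0 HCAs" not in proc.stdout


    if __name__ == "__main__":
        print("RDMA devices visible:", rdma_devices_visible())
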
9 changes: 7 additions & 2 deletions axlearn/cloud/gcp/job.py
@@ -24,7 +24,6 @@
from axlearn.cloud.common.utils import subprocess_run
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone, gcp_settings
from axlearn.cloud.gcp.jobset_utils import (
A3ReplicatedJob,
AcceleratorConfig,
BaseReplicatedJob,
TPUReplicatedJob,
@@ -304,6 +303,10 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
cfg.builder = cls.builder.from_flags(fv, **kwargs)
return cfg

@classmethod
def with_builder(cls, builder: type[BaseReplicatedJob]):
return type(f"{cls.__name__}_{builder.__name__}", (cls,), {"builder": builder})

def __init__(self, cfg):
bundler_cfg = cfg.bundler
bundler_cfg = getattr(bundler_cfg, "inner", bundler_cfg)
@@ -395,9 +398,11 @@ class GPUGKEJob(GKEJob):
"""A GPU job represented as a k8s JobSet.

See also `gke_runner` as an example.

Builder is set dynamically based on the instance type.
e.g. GKEJob.with_builder(A3UltraReplicatedJob).
"""

builder = A3ReplicatedJob
Config = GKEJob.Config


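The job.py change replaces the hard-coded A3ReplicatedJob builder with a small class factory. A self-contained sketch of the with_builder pattern (stand-in classes, not the actual AXLearn implementations), showing how it produces a readably named subclass with the builder pinned as a class attribute:

    # Stand-in classes to illustrate the with_builder pattern above; the real
    # GKEJob and A3UltraReplicatedJob live in axlearn.cloud.gcp.

    class BaseReplicatedJob:
        """Placeholder for a jobset builder."""


    class A3UltraReplicatedJob(BaseReplicatedJob):
        """Placeholder builder for A3 Ultra instances."""


    class GKEJob:
        builder: type = None

        @classmethod
        def with_builder(cls, builder: type):
            # type() creates a new subclass with `builder` fixed as a class
            # attribute and a descriptive name for logs and error messages.
            return type(f"{cls.__name__}_{builder.__name__}", (cls,), {"builder": builder})


    job_cls = GKEJob.with_builder(A3UltraReplicatedJob)
    print(job_cls.__name__)   # GKEJob_A3UltraReplicatedJob
    print(job_cls.builder)    # <class '__main__.A3UltraReplicatedJob'>
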
41 changes: 18 additions & 23 deletions axlearn/cloud/gcp/job_test.py
@@ -29,6 +29,11 @@
from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler, CloudBuildBundler, GCSTarBundler
from axlearn.cloud.gcp.config import gcp_settings
from axlearn.cloud.gcp.job import CPUJob, TPUQRMJob, _kill_ssh_agent, _start_ssh_agent
from axlearn.cloud.gcp.jobset_utils import (
A3HighReplicatedJob,
A3UltraReplicatedJob,
BaseReplicatedJob,
)
from axlearn.cloud.gcp.jobset_utils_test import mock_settings
from axlearn.cloud.gcp.test_utils import mock_gcp_settings
from axlearn.cloud.gcp.tpu import create_queued_tpu, delete_queued_tpu, infer_tpu_type, qrm_resource
@@ -278,6 +283,7 @@ class GPUGKEJobTest(TestCase):
@contextlib.contextmanager
def _job_config(
self,
replicated_job_cls: type[BaseReplicatedJob],
bundler_cls: type[Bundler],
service_account: Optional[str] = None,
queue: Optional[str] = None,
@@ -288,13 +294,13 @@ def _job_config(
[job.__name__, jobset_utils.__name__, bundler.__name__], mock_settings()
):
fv = flags.FlagValues()
job.GPUGKEJob.define_flags(fv)
job.GPUGKEJob.with_builder(replicated_job_cls).define_flags(fv)
if service_account:
fv.set_default("service_account", service_account)
if num_replicas:
fv.set_default("num_replicas", num_replicas)
fv.mark_as_parsed()
cfg = job.GPUGKEJob.from_flags(fv)
cfg = job.GPUGKEJob.with_builder(replicated_job_cls).from_flags(fv)
cfg.bundler = bundler_cls.from_spec([], fv=fv).set(image="test-image")
cfg.accelerator.instance_type = "gpu-a3-highgpu-8g-256"
cfg.queue = queue
@@ -304,6 +310,7 @@ def _job_config(
yield cfg

@parameterized.product(
replicated_job_cls=[A3HighReplicatedJob, A3UltraReplicatedJob],
service_account=[None, "sa"],
queue=[None, "queue-name"],
bundler_cls=[ArtifactRegistryBundler, CloudBuildBundler],
@@ -312,7 +319,14 @@ def _job_config(
env_vars=[None, {"a": "b"}],
)
def test_instantiate(
self, service_account, bundler_cls, wrap_bundler, num_replicas, env_vars, queue
self,
replicated_job_cls,
service_account,
bundler_cls,
wrap_bundler,
num_replicas,
env_vars,
queue,
):
class WrappedBundler(Bundler):
@config_class
@@ -321,6 +335,7 @@ class Config(Bundler.Config):

settings = mock_settings()
with self._job_config(
replicated_job_cls,
bundler_cls,
service_account=service_account,
env_vars=env_vars,
@@ -351,26 +366,6 @@ class Config(Bundler.Config):
else:
self.assertEqual(num_replicas, job_cfg.accelerator.num_replicas)

@parameterized.product(
bundler_cls=[ArtifactRegistryBundler, CloudBuildBundler],
queue=[None, "queue-name"],
)
def test_build_jobset(
self,
bundler_cls,
queue: Optional[str] = None,
):
with self._job_config(bundler_cls, queue=queue) as cfg:
gke_job: job.GPUGKEJob = cfg.set(name="test").instantiate()
# pylint: disable-next=protected-access
jobset = gke_job._build_jobset()
jobset_annotations = jobset["metadata"]["annotations"]
self.assertEqual(jobset["metadata"]["name"], cfg.name)
if queue is None:
self.assertNotIn("kueue.x-k8s.io/queue-name", jobset_annotations)
else:
self.assertEqual(jobset_annotations["kueue.x-k8s.io/queue-name"], queue)


if __name__ == "__main__":
_private_flags()
26 changes: 14 additions & 12 deletions axlearn/cloud/gcp/jobs/gke_runner.py
@@ -48,10 +48,14 @@
from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler
from axlearn.cloud.gcp.config import gcp_settings
from axlearn.cloud.gcp.event_queue import event_queue_from_config
from axlearn.cloud.gcp.job import GCPJob, GKEJob, GPUGKEJob, TPUGKEJob
from axlearn.cloud.gcp.job import GCPJob, GKEJob, TPUGKEJob
from axlearn.cloud.gcp.jobs import runner_utils
from axlearn.cloud.gcp.jobs.tpu_runner import with_tpu_training_defaults
from axlearn.cloud.gcp.jobset_utils import BASTION_JOB_VERSION_LABEL
from axlearn.cloud.gcp.jobset_utils import (
BASTION_JOB_VERSION_LABEL,
A3HighReplicatedJob,
A3UltraReplicatedJob,
)
from axlearn.cloud.gcp.node_pool import (
PRE_PROVISIONER_LABEL,
delete_node_pools,
@@ -143,6 +147,10 @@ def validate_inner(cls):
if cls.inner is None:
raise ValueError(f"A GKERunnerJob should subclass {cls} and define `inner`.")

@classmethod
def with_inner(cls, inner: type[GKEJob]):
return type(f"{cls.__name__}_{inner.__name__}", (cls,), {"inner": inner})

@classmethod
def define_flags(cls, fv: flags.FlagValues = FLAGS):
super().define_flags(fv)
@@ -531,19 +539,13 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
return cfg


class GPUGKERunnerJob(GKERunnerJob):
"""A GKERunnerJob that uses GPUGKEJob."""

inner = GPUGKEJob


def _get_runner_or_exit(instance_type: str):
if instance_type.startswith("tpu"):
return TPUGKERunnerJob
elif instance_type.startswith("gpu-a3"):
# TODO(markblee): We can directly construct:
# GKERunnerJob.with_inner(GKEJob.with_jobset(A3ReplicatedJob))
return GPUGKERunnerJob
elif instance_type.startswith("gpu-a3-ultra"):
return GKERunnerJob.with_inner(GKEJob.with_builder(A3UltraReplicatedJob))
elif instance_type.startswith("gpu-a3-high"):
return GKERunnerJob.with_inner(GKEJob.with_builder(A3HighReplicatedJob))
Comment on lines +545 to +548
Contributor: Neat!
Contributor (Author): This is like praising yourself since you came up with it lol

else:
raise app.UsageError(f"Unknown instance_type {instance_type}")

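Continuing the sketch from the job.py section (same stand-ins, not the real AXLearn classes), the runner-side with_inner factory composes with with_builder, mirroring what _get_runner_or_exit now returns per instance-type prefix:

    # Stand-ins continuing the earlier sketch; not the actual AXLearn classes.

    class BaseReplicatedJob: ...
    class A3UltraReplicatedJob(BaseReplicatedJob): ...


    class GKEJob:
        builder: type = None

        @classmethod
        def with_builder(cls, builder: type):
            return type(f"{cls.__name__}_{builder.__name__}", (cls,), {"builder": builder})


    class GKERunnerJob:
        inner: type = None

        @classmethod
        def with_inner(cls, inner: type):
            return type(f"{cls.__name__}_{inner.__name__}", (cls,), {"inner": inner})


    runner_cls = GKERunnerJob.with_inner(GKEJob.with_builder(A3UltraReplicatedJob))
    assert runner_cls.inner.builder is A3UltraReplicatedJob
    print(runner_cls.__name__)  # GKERunnerJob_GKEJob_A3UltraReplicatedJob
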
158 changes: 9 additions & 149 deletions axlearn/cloud/gcp/jobs/gke_runner_test.py
@@ -14,7 +14,7 @@

from axlearn.cloud.common.bastion import BASTION_JOB_VERSION_ENV_VAR
from axlearn.cloud.gcp import bundler, node_pool_provisioner
from axlearn.cloud.gcp.job import GPUGKEJob, TPUGKEJob
from axlearn.cloud.gcp.job import TPUGKEJob
from axlearn.cloud.gcp.jobs import gke_runner
from axlearn.cloud.gcp.jobs.bastion_vm_test import _mock_job
from axlearn.cloud.gcp.jobs.gke_runner import (
@@ -55,151 +55,6 @@ def _mock_replicated_jobs(reservations: Sequence[str], bastion_job_version: Opti
]


class GPUGKERunnerJobTest(parameterized.TestCase):
"""Tests GPUGKERunnerJob."""

@contextlib.contextmanager
def _job_config(
self,
*,
name: str,
cluster: str,
service_account: str,
gcsfuse_mount_spec: Optional[str] = None,
) -> Iterator[tuple[gke_runner.GPUGKERunnerJob.Config, dict]]:
mock_user = mock.patch("os.environ", {"USER": "test"})
mock_settings = {
"project": "settings-project",
"zone": "settings-zone-a",
"ttl_bucket": "settings-ttl-bucket",
"gke_cluster": "settings-cluster",
"default_dockerfile": "settings-dockerfile",
"docker_repo": "settings-repo",
}
with (
mock_user,
mock_gcp_settings(
[gke_runner.__name__, bundler.__name__, node_pool_provisioner.__name__],
mock_settings,
),
):
fv = flags.FlagValues()
gke_runner.GPUGKERunnerJob.define_flags(fv)
if name:
fv.set_default("name", name)
if cluster:
fv.set_default("cluster", cluster)
if service_account:
fv.set_default("service_account", service_account)
if gcsfuse_mount_spec:
fv.set_default("gcsfuse_mount_spec", gcsfuse_mount_spec)
fv.set_default("instance_type", "gpu-a3-highgpu-8g-256")
fv.mark_as_parsed()
yield gke_runner.GPUGKERunnerJob.from_flags(fv), mock_settings

@parameterized.product(
name=[None, "test-name"],
cluster=[None, "test-cluster"],
service_account=[None, "test-sa"],
gcsfuse_mount_spec=[None, ["gcs_path=my-test-path"]],
)
def test_from_flags(self, name, cluster, service_account, gcsfuse_mount_spec):
with self._job_config(
name=name,
cluster=cluster,
service_account=service_account,
gcsfuse_mount_spec=gcsfuse_mount_spec,
) as (cfg, mock_settings):
if name:
self.assertEqual(cfg.name, name)
else:
self.assertIsNotNone(cfg.name)
self.assertEqual(cfg.cluster, cluster or mock_settings["gke_cluster"])
self.assertEqual(cfg.service_account, service_account or "default")
if gcsfuse_mount_spec:
fuse = cast(GPUGKEJob.Config, cfg.inner).builder.gcsfuse_mount
self.assertEqual(fuse.gcs_path, "my-test-path")

@parameterized.product(
status=[
gke_runner.GKERunnerJob.Status.FAILED,
gke_runner.GKERunnerJob.Status.SUCCEEDED,
gke_runner.GKERunnerJob.Status.COMPLETED,
],
)
def test_exit(self, status):
with self._job_config(
name="test-name",
cluster="test-cluster",
service_account="test-sa",
) as (cfg, _):
cfg.bundler.set(image="test")
job: gke_runner.GPUGKERunnerJob = cfg.set(command="").instantiate()

mock_job = mock.patch.multiple(
job, _get_status=mock.Mock(return_value=status), _delete=mock.DEFAULT
)

with mock_job:
job._execute()

def test_delete(self):
with self._job_config(
name="test-name",
cluster="test-cluster",
service_account="test-sa",
) as (cfg, _):
cfg.bundler.set(image="test")

job: gke_runner.GPUGKERunnerJob = cfg.set(
command="", status_interval_seconds=0
).instantiate()

mock_job = mock.patch.multiple(
job,
_inner=mock.DEFAULT,
_pre_provisioner=mock.DEFAULT,
)

with mock_job:
job._delete()
job._inner._delete.assert_called() # pytype: disable=attribute-error

def test_start(self):
with self._job_config(
name="test-name",
cluster="test-cluster",
service_account="test-sa",
) as (
cfg,
_,
):
cfg.bundler.set(image="test")

job: gke_runner.GPUGKERunnerJob = cfg.set(
command="",
status_interval_seconds=0,
).instantiate()

mock_job = mock.patch.multiple(
job,
_get_status=mock.Mock(
side_effect=[
gke_runner.GKERunnerJob.Status.NOT_STARTED,
gke_runner.GKERunnerJob.Status.COMPLETED,
]
),
_get_job_credentials=mock.DEFAULT,
_delete=mock.DEFAULT,
_inner=mock.DEFAULT,
_pre_provisioner=mock.DEFAULT,
)

with mock_job:
job._execute()
job._inner.execute.assert_called() # pytype: disable=attribute-error


class TPUGKERunnerJobTest(parameterized.TestCase):
"""Tests TPUGKERunnerJob."""

@@ -1009,20 +864,25 @@ class MainTest(parameterized.TestCase):
@parameterized.parameters(
dict(instance_type="tpu", expected=gke_runner.TPUGKERunnerJob),
dict(instance_type="tpu-v4-8", expected=gke_runner.TPUGKERunnerJob),
dict(instance_type="gpu-a3-highgpu-8g-256", expected=gke_runner.GPUGKERunnerJob),
dict(instance_type="gpu-a3-highgpu-8g-256", expected=gke_runner.GKERunnerJob),
dict(instance_type="gpu", expected=app.UsageError("instance_type")),
)
def test_get_runner_or_exit(self, instance_type: str, expected: Union[Exception, type]):
if isinstance(expected, Exception):
with self.assertRaisesRegex(type(expected), str(expected)):
_get_runner_or_exit(instance_type)
else:
self.assertEqual(expected, _get_runner_or_exit(instance_type))
actual_runner = _get_runner_or_exit(instance_type)
# For GPU cases, check that it is a subclass of GKERunnerJob.
if instance_type.startswith("gpu-"):
self.assertTrue(issubclass(actual_runner, expected))
else:
self.assertEqual(expected, actual_runner)

@parameterized.product(
[
dict(runner=gke_runner.TPUGKERunnerJob, instance_type="tpu-v4-8"),
dict(runner=gke_runner.GPUGKERunnerJob, instance_type="gpu-a3-highgpu-8g-256"),
dict(runner=gke_runner.GKERunnerJob, instance_type="gpu-a3-highgpu-8g-256"),
],
action=["start", "stop", "update"],
)
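
One consequence of the factory approach, and the reason the GPU assertion above switched from equality to issubclass: every with_inner/with_builder call creates a fresh class object. A minimal self-contained illustration (stand-in classes, not the real ones):

    # Why equality no longer works for the GPU runners: each factory call
    # returns a brand-new dynamically created subclass.

    class GKERunnerJob:
        inner: type = None

        @classmethod
        def with_inner(cls, inner: type):
            return type(f"{cls.__name__}_{inner.__name__}", (cls,), {"inner": inner})


    class FakeInnerJob:  # stand-in for GKEJob.with_builder(...)
        pass


    a = GKERunnerJob.with_inner(FakeInnerJob)
    b = GKERunnerJob.with_inner(FakeInnerJob)
    assert a is not b                   # two distinct classes, so `==` fails
    assert issubclass(a, GKERunnerJob)  # but the subclass relationship holds
    assert issubclass(b, GKERunnerJob)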