Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docker login at runtime #710

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion WDL/runtime/backend/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,27 @@
"""

import os
import re
import json
import stat
import time
import shlex
import uuid
import shlex
import base64
import random
import hashlib
import logging
import warnings
import threading
import traceback
import contextlib
from enum import Enum
from io import BytesIO
from typing import List, Dict, Set, Optional, Any, Callable, Tuple, Iterable
import boto3
import docker
import google.auth
import google.auth.transport.requests
from ... import Error
from ..._util import chmod_R_plus, TerminationSignalFlag
from ..._util import StructuredLogMessage as _
Expand All @@ -26,6 +32,15 @@
from ..task_container import TaskContainer


logging.getLogger("botocore").setLevel(logging.WARNING)


class SupportedProviders(Enum):
    """
    Container-registry providers for which a runtime docker login is attempted.

    Members are looked up by value, e.g. ``SupportedProviders("aws")``; looking up
    ``None`` yields ``UNKNOWN``.
    """

    # Amazon Elastic Container Registry
    AWS = "aws"

    # Google Artifact Registry / Container Registry
    GCP = "gcp"

    # the image's registry did not match any recognized provider pattern
    UNKNOWN = None


class SwarmContainer(TaskContainer):
"""
TaskContainer docker (swarm) runtime
Expand Down Expand Up @@ -325,6 +340,20 @@ def resolve_tag(
try:
image_attrs = client.images.get(image_tag).attrs
except docker.errors.ImageNotFound:
try:
# docker.errors.APIError is thrown if permissions are missing
client.images.get_registry_data(image_tag) # type: ignore[attr-defined]
except docker.errors.APIError:
logger.debug(f"Need to login to {image_tag} registry")
registry_name, provider = self.get_registry_name_and_provider(logger, image_tag)
if registry_name and provider is SupportedProviders.AWS:
self.aws_ecr_login(logger, client, registry_name)
if registry_name and provider is SupportedProviders.GCP:
self.gcp_docker_registry_login(client, registry_name)
if provider is SupportedProviders.UNKNOWN:
logger.warning(
f"{image_tag} registry pattern unrecognized. If login is needed do it before running the workflow"
)
try:
logger.info(_("docker pull", tag=image_tag))
client.images.pull(image_tag)
Expand All @@ -342,6 +371,62 @@ def resolve_tag(
logger.notice(_("docker image", **image_log))
return image_tag

def get_registry_name_and_provider(
    self, logger: logging.Logger, image_tag: str
) -> Tuple[Optional[str], SupportedProviders]:
    """
    Work out which registry an image tag points at and which provider hosts it.

    :param logger: logger receiving debug output
    :param image_tag: full image reference, e.g. ``<registry>/<path>/<image>:<tag>``
    :return: ``(registry_name, provider)``. ``registry_name`` is the part of the tag to
             log in to, or ``None`` (with ``SupportedProviders.UNKNOWN``) if the tag does
             not match any recognized registry pattern.
    """
    logger.debug(f"Get registry name and provider for {image_tag}")
    # GCP:
    # - <LOCATION>-docker.pkg.dev/<PROJECT-ID>/<REPOSITORY>
    #   LOCATION is a region (us-central1) or a multi-region (us, europe, asia), so the
    #   trailing digits are optional.
    # - gcr.io/<PROJECT-ID> or <LOCATION>.gcr.io/<PROJECT-ID> (legacy)
    #   Any prefix must be whole dot-terminated labels: a host that merely *ends* in
    #   "gcr.io" (e.g. "evilgcr.io") must not match, since a match causes Google
    #   credentials to be sent to that host.
    gcp_registry_pattern = (
        r"^(?P<gcp>[a-z-]+[0-9]*-docker\.pkg\.dev/[a-z0-9-]+/[a-z0-9-]+"
        r"|(?:[a-z0-9-]+\.)*gcr\.io)/.*$"
    )
    # AWS:
    # - <AWS_ACCOUNT_ID>.dkr.ecr.<REGION>.amazonaws.com
    aws_registry_pattern = r"^(?P<aws>[0-9]{12}\.dkr\.ecr\.[a-z-]+[0-9]+\.amazonaws\.com)/.*$"

    pattern_match = re.match(gcp_registry_pattern, image_tag) or re.match(
        aws_registry_pattern, image_tag
    )
    if pattern_match is None:
        registry_name = None
        provider = SupportedProviders.UNKNOWN
    else:
        registry_name = pattern_match.group(1)
        # each pattern has exactly one (named) capturing group, whose name is the
        # provider's enum value
        provider = SupportedProviders(pattern_match.lastgroup)
    logger.debug(f"Registry: {registry_name}. Provider: {provider}")
    return registry_name, provider

def aws_ecr_login(
    self, logger: logging.Logger, docker_client: docker.DockerClient, registry_name: str
) -> None:
    """
    Log the docker client in to an AWS ECR registry using a freshly issued ECR token.

    :param logger: logger receiving debug output
    :param docker_client: docker client to authenticate
    :param registry_name: registry host of the form
                          ``<AWS_ACCOUNT_ID>.dkr.ecr.<REGION>.amazonaws.com``
    :raises ValueError: if ``registry_name`` does not consist of exactly six
                        dot-separated parts
    """
    logger.debug(f"Get region and account ID from registry name {registry_name}")
    # Unused parts get descriptive throw-away names rather than `_`, which this module
    # uses as the alias for StructuredLogMessage.
    aws_account_id, _dkr, _ecr, aws_region, _amazonaws, _com = registry_name.split(".")
    logger.debug(f"AWS account: {aws_account_id}. Region: {aws_region}")
    ecr_client = boto3.client("ecr", region_name=aws_region)
    logger.debug(f"Get ECR token for {registry_name}")
    response = ecr_client.get_authorization_token(registryIds=[aws_account_id])
    token = base64.b64decode(response["authorizationData"][0]["authorizationToken"]).decode(
        "utf-8"
    )
    # The decoded token carries the user name as an "AWS:" prefix. Strip only that
    # leading prefix: removing every "AWS:" occurrence would corrupt a password that
    # happens to contain the same sequence.
    user_prefix = "AWS:"
    ecr_password = token[len(user_prefix) :] if token.startswith(user_prefix) else token
    logger.debug(f"Login to {registry_name}")
    self.docker_login(docker_client, "AWS", ecr_password, registry_name)

# Log the docker client in to a Google-hosted registry: obtain default Google
# credentials with the cloud-platform scope, refresh them, then pass the resulting
# token as the password for the "oauth2accesstoken" user via self.docker_login.
#
# client: docker client to authenticate
# registry_name: registry to log in to
def gcp_docker_registry_login(self, client: docker.DockerClient, registry_name: str) -> None:
# All warnings raised while this context is active are ignored.
# NOTE(review): indentation is lost in this view, so it is unclear how many of the
# statements below belong to the `with` body -- confirm against the real file.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The second return value is discarded. NOTE(review): binding it to `_` makes `_`
# a local here, while the module imports StructuredLogMessage under that name.
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
# refresh() is called before creds.token is read below
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
self.docker_login(client, "oauth2accesstoken", creds.token, registry_name)

def docker_login(
    self, client: docker.DockerClient, username: str, password: str, registry_name: str
) -> None:
    """
    Authenticate ``client`` against ``registry_name`` with the given credentials.

    :param client: docker client to log in
    :param username: registry user name
    :param password: registry password or access token
    :param registry_name: registry to log in to
    """
    # reauth=True is always requested, so the login is redone on every call
    # (presumably to replace any cached credentials -- see docker-py's login docs)
    login_options = {"registry": registry_name, "reauth": True}
    client.login(username, password, **login_options)  # type: ignore[attr-defined]

def prepare_mounts(self, logger: logging.Logger) -> List[docker.types.Mount]:
def escape(s):
# docker processes {{ interpolations }}
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ dependencies = [
"python-json-logger>=2,<3",
"lark~=1.1",
"bullet>=2,<3",
"psutil>=5,<7"
"psutil>=5,<7",
"google-auth>=2.32.0",
"boto3>=1.34.153",
"boto3-stubs>=1.34.153",
]

[project.optional-dependencies]
Expand Down