Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Enable flytekit to authenticate with proxy in front of FlyteAdmin #1787

Merged
merged 24 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
2794d00
Introduce authenticator engine and make proxy auth work
Jul 17, 2023
5d157e9
Use proxy authed session for client credentials flow
Jul 17, 2023
125a7ee
Don't use authenticator engine but do proxy authentication via existi…
Aug 11, 2023
ad8d556
Add docstring to AuthenticationHTTPAdapter
Aug 11, 2023
e0d0de9
Address todo in docstring
Aug 11, 2023
fc393a3
Create blank session if none provided
Aug 11, 2023
21710e0
Create blank session if none provided in get_token
Aug 11, 2023
397654c
Refresh proxy creds in session when not existing without triggering 401
Aug 11, 2023
6902036
Add test for get_session
Aug 11, 2023
977163d
Move auth helper test into existing module
Aug 11, 2023
eda1805
Move auth helper test into existing module
Aug 11, 2023
79b9dae
Add test for upgrade_channel_to_proxy_authenticated
Aug 11, 2023
40378c1
Auth helper tests without use of responses package
Aug 17, 2023
085319d
Feat: Add plugin for generating GCP IAP ID tokens via external comman…
fg91 Aug 24, 2023
a237066
Use proxy auth'ed session for device code auth flow
Aug 25, 2023
caa3653
Fix token client tests
Aug 25, 2023
2ebad82
Make poll token endpoint test more specific
Aug 25, 2023
f9ddc8c
Make test_client_creds_authenticator test work and more specific
Aug 25, 2023
c677ff3
Make test_client_creds_authenticator_with_custom_scopes test work and…
Aug 25, 2023
336aebc
Implement subcommand to generate id tokens for service accounts
fg91 Sep 7, 2023
7ad826d
Test id token generation from service accounts
fg91 Sep 9, 2023
b35988a
Fix plugin requirements
fg91 Sep 9, 2023
a713c56
Document usage of generate-service-account-id-token subcommand
Sep 10, 2023
2437b84
Document alternative ways to obtain service account id tokens
Sep 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 55 additions & 24 deletions flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def __init__(
redirect_uri: typing.Optional[str] = None,
endpoint_metadata: typing.Optional[EndpointMetadata] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[_requests.Session] = None,
request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
):
"""
Create new AuthorizationClient
Expand All @@ -192,7 +197,9 @@ def __init__(
:param auth_endpoint: str endpoint where auth metadata can be found
:param token_endpoint: str endpoint to retrieve token from
:param scopes: list[str] oauth2 scopes
:param client_id
:param client_id: oauth2 client id
:param redirect_uri: oauth2 redirect uri
:param endpoint_metadata: EndpointMetadata object that controls the rendering of the page shown after a successful or failed login
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
Expand All @@ -201,6 +208,15 @@ def __init__(
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
:param session: (optional) A custom requests.Session object to use for making HTTP requests.
If not provided, a new Session object will be created.
:param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser
:param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token
:param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token
:param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to
the parameters sent when exchanging the auth code for the access token. Defaults to False.
Required e.g. for the PKCE flow with flyteadmin.
Not required for e.g. the standard OAuth2 flow on GCP.
"""
self._endpoint = endpoint
self._auth_endpoint = auth_endpoint
Expand All @@ -213,15 +229,13 @@ def __init__(
self._client_id = client_id
self._scopes = scopes or []
self._redirect_uri = redirect_uri
self._code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(self._code_verifier)
self._code_challenge = code_challenge
state = _generate_state_parameter()
self._state = state
self._verify = verify
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._session = session or _requests.Session()

self._params = {
self._request_auth_code_params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
"response_type": "code", # Indicates the authorization code grant
"scope": " ".join(s.strip("' ") for s in self._scopes).strip(
Expand All @@ -230,10 +244,18 @@ def __init__(
# callback location where the user-agent will be directed to.
"redirect_uri": self._redirect_uri,
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}

if request_auth_code_params:
# Allow adding additional parameters to the request_auth_code_params
self._request_auth_code_params.update(request_auth_code_params)

self._request_access_token_params = request_access_token_params or {}
self._refresh_access_token_params = refresh_access_token_params or {}

if add_request_auth_code_params_to_request_access_token_params:
self._request_access_token_params.update(self._request_auth_code_params)

def __repr__(self):
    """Return a debug representation naming the auth endpoint, token endpoint, client id, scopes and redirect URI."""
    shown = (
        self._auth_endpoint,
        self._token_endpoint,
        self._client_id,
        self._scopes,
        self._redirect_uri,
    )
    # format(value) applies the same empty format spec an f-string placeholder uses.
    rendered = ", ".join(format(value) for value in shown)
    return "AuthorizationClient(" + rendered + ")"

Expand All @@ -249,7 +271,7 @@ def _create_callback_server(self):

def _request_authorization_code(self):
scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint)
query = _urlencode(self._params)
query = _urlencode(self._request_auth_code_params)
endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None))
logging.debug(f"Requesting authorization code through {endpoint}")
_webbrowser.open_new_tab(endpoint)
Expand All @@ -262,33 +284,38 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
"refresh_token": "bar",
"token_type": "Bearer"
}

Can additionally contain "expires_in" and "id_token" fields.
"""
response_body = auth_token_resp.json()
refresh_token = None
id_token = None
if "access_token" not in response_body:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]
if "id_token" in response_body:
id_token = response_body["id_token"]

return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")
self._params.update(
{
"code": auth_code.code,
"code_verifier": self._code_verifier,
"grant_type": "authorization_code",
}
)

resp = _requests.post(
params = {
"code": auth_code.code,
"grant_type": "authorization_code",
}

params.update(self._request_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data=self._params,
data=params,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down Expand Up @@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials:
if credentials.refresh_token is None:
raise ValueError("no refresh token available with which to refresh authorization credentials")

resp = _requests.post(
data = {
"refresh_token": credentials.refresh_token,
"grant_type": "refresh_token",
"client_id": self._client_id,
}

data.update(self._refresh_access_token_params)

resp = self._session.post(
url=self._token_endpoint,
data={
"grant_type": "refresh_token",
"client_id": self._client_id,
"refresh_token": credentials.refresh_token,
},
data=data,
headers=self._headers,
allow_redirects=False,
verify=self._verify,
Expand Down
32 changes: 31 additions & 1 deletion flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass

import click
import requests

from . import token_client
from .auth_client import AuthorizationClient
Expand Down Expand Up @@ -95,16 +96,24 @@ def __init__(
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
"""
Initialize with default creds from KeyStore using the endpoint name
"""
super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify)
self._cfg_store = cfg_store
self._auth_client = None
self._session = session or requests.Session()

def _initialize_auth_client(self):
if not self._auth_client:

from .auth_client import _create_code_challenge, _generate_code_verifier

code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(code_verifier)

cfg = self._cfg_store.get_client_config()
self._set_header_key(cfg.header_key)
self._auth_client = AuthorizationClient(
Expand All @@ -115,6 +124,16 @@ def _initialize_auth_client(self):
auth_endpoint=cfg.authorization_endpoint,
token_endpoint=cfg.token_endpoint,
verify=self._verify,
session=self._session,
request_auth_code_params={
"code_challenge": code_challenge,
"code_challenge_method": "S256",
},
request_access_token_params={
"code_verifier": code_verifier,
},
refresh_access_token_params={},
add_request_auth_code_params_to_request_access_token_params=True,
)

def refresh_credentials(self):
Expand Down Expand Up @@ -176,6 +195,7 @@ def __init__(
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
audience: typing.Optional[str] = None,
session: typing.Optional[requests.Session] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -186,6 +206,7 @@ def __init__(
self._client_id = client_id
self._client_secret = client_secret
self._audience = audience or cfg.audience
self._session = session or requests.Session()
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify)

def refresh_credentials(self):
Expand All @@ -211,6 +232,7 @@ def refresh_credentials(self):
verify=self._verify,
scopes=scopes,
audience=audience,
session=self._session,
)

logging.info("Retrieved new token, expires in {}".format(expires_in))
Expand All @@ -234,6 +256,7 @@ def __init__(
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -245,6 +268,7 @@ def __init__(
raise AuthenticationError(
"Device Authentication is not available on the Flyte backend / authentication server"
)
self._session = session or requests.Session()
super().__init__(
endpoint=endpoint,
header_key=header_key or cfg.header_key,
Expand All @@ -255,7 +279,13 @@ def __init__(

def refresh_credentials(self):
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify
self._device_auth_endpoint,
self._client_id,
self._audience,
self._scope,
self._http_proxy_url,
self._verify,
self._session,
)
text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: {click.style(resp.user_code, fg='blue')}"
click.secho(text)
Expand Down
30 changes: 22 additions & 8 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import keyring as _keyring
from keyring.errors import NoKeyringError
from keyring.errors import NoKeyringError, PasswordDeleteError


@dataclass
Expand All @@ -16,6 +16,7 @@ class Credentials(object):
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None
id_token: typing.Optional[str] = None


class KeyringStore:
Expand All @@ -25,20 +26,28 @@ class KeyringStore:

_access_token_key = "access_token"
_refresh_token_key = "refresh_token"
_id_token_key = "id_token"

@staticmethod
def store(credentials: Credentials) -> Credentials:
try:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
if credentials.refresh_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._refresh_token_key,
credentials.refresh_token,
)
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._access_token_key,
credentials.access_token,
)
if credentials.id_token:
_keyring.set_password(
credentials.for_endpoint,
KeyringStore._id_token_key,
credentials.id_token,
)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return credentials
Expand All @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
try:
refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key)
id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key)
except NoKeyringError as e:
logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
return None

if not access_token:
if not access_token and not id_token:
return None
return Credentials(access_token, refresh_token, for_endpoint)
return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token)

@staticmethod
def delete(for_endpoint: str):
    """
    Remove every token cached for ``for_endpoint`` from the keyring.

    Each entry is deleted independently: ``store`` only writes the refresh and id tokens
    when they are set, so a missing entry must not prevent the remaining ones from being removed.

    :param for_endpoint: keyring service name under which the tokens were stored
    """
    for token_key in (
        KeyringStore._access_token_key,
        KeyringStore._refresh_token_key,
        KeyringStore._id_token_key,
    ):
        try:
            _keyring.delete_password(for_endpoint, token_key)
        except PasswordDeleteError as e:
            # The entry was never stored (or is already gone); keep going with the other tokens.
            logging.debug(f"{token_key} not found in key store, not deleting. Error: {e}")
        except NoKeyringError as e:
            # No keyring backend at all: none of the remaining deletions can succeed either.
            logging.debug(f"KeyRing not available, tokens cannot be deleted. Error: {e}")
            return
11 changes: 9 additions & 2 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_token(
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -103,7 +104,10 @@ def get_token(
body["audience"] = audience

proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not session:
session = requests.Session()
response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify)

if not response.ok:
j = response.json()
Expand All @@ -125,6 +129,7 @@ def get_device_code(
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
verify: typing.Optional[typing.Union[bool, str]] = None,
session: typing.Optional[requests.Session] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device authentication code that can be used to authenticate the request using a browser on a
Expand All @@ -133,7 +138,9 @@ def get_device_code(
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not session:
session = requests.Session()
resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())
Expand Down
Loading