Skip to content

Commit

Permalink
Added support for certificate authentication with MSGraphAsyncOperator (
Browse files Browse the repository at this point in the history
apache#45935)

* refactor: Added support for certificate authentication in KiotaRequestAdapterHook

* refactor: Fixed label for allowed hosts in MS Graph connection form

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Jan 26, 2025
1 parent 29b9e8e commit ff1e3a6
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 23 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
135 changes: 135 additions & 0 deletions docs/apache-airflow-providers-microsoft-azure/connections/msgraph.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
.. _howto/connection:msgraph:

Microsoft Graph API Connection
==============================

The Microsoft Graph API connection type enables Microsoft Graph API Integrations.

The :class:`~airflow.providers.microsoft.azure.hooks.msgraph.KiotaRequestAdapterHook` and :class:`~airflow.providers.microsoft.azure.operators.msgraph.MSGraphAsyncOperator` requires a connection of type ``msgraph`` to authenticate with Microsoft Graph API.

Authenticating to Microsoft Graph API
-------------------------------------

1. Use `token credentials
<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=cmd#authenticate-with-token-credentials>`_
i.e. add specific credentials (client_id, client_secret, tenant_id) to the Airflow connection.

Default Connection IDs
----------------------

All hooks and operators related to Microsoft Graph API use ``msgraph_default`` by default.

Configuring the Connection
--------------------------

Client ID
Specify the ``client_id`` used for the initial connection.
This is needed for *token credentials* authentication mechanism.


Client Secret
Specify the ``client_secret`` used for the initial connection.
This is needed for *token credentials* authentication mechanism unless a certificate is used.


Tenant ID
Specify the ``tenant_id`` used for the initial connection.
This is needed for *token credentials* authentication mechanism.


API Version
Specify the ``api_version`` used for the initial connection.
Default value is ``v1.0``.


Authority
The ``authority`` parameter defines the endpoint (or tenant) that MSAL uses to authenticate requests.
It determines which identity provider will handle authentication.
Default value is ``login.microsoftonline.com``.


Scopes
The ``scopes`` parameter specifies the permissions or access rights that your application is requesting for a connection.
These permissions define what resources or data your application can access on behalf of the user or application.
Default value is ``https://graph.microsoft.com/.default``.


Certificate path
The ``certificate_path`` parameter specifies the filepath where the certificate is located.
Both ``certificate_path`` and ``certificate_data`` parameter cannot be used together, they should be mutually exclusive.
Default value is None.


Certificate data
The ``certificate_date`` parameter specifies the certificate as a string.
Both ``certificate_path`` and ``certificate_data`` parameter cannot be used together, they should be mutually exclusive.
Default value is None.


Disable instance discovery
The ``disable_instance_discovery`` parameter determines whether MSAL should validate and discover Azure AD endpoints dynamically during runtime.
Default value is False (e.g. disabled).


Allowed hosts
The ``allowed_hosts`` parameter is used to define a list of acceptable hosts that the authentication provider will trust when making requests.
This parameter is particularly useful for enhancing security and controlling which endpoints the authentication provider interacts with.


Proxies
The ``proxies`` parameter is used to define a dict for the ``http`` and ``https`` schema, the ``no`` key can be use to define hosts not to be used by the proxy.
Default value is None.


Verify environment
The ``verify`` parameter specifies whether SSL certificates should be verified when making HTTPS requests.
By default, ``verify`` parameter is set to True. This means that the `httpx <https://www.python-httpx.org>`_ library will verify the SSL certificate presented by the server to ensure:

- The certificate is valid and trusted.
- The certificate matches the hostname of the server.
- The certificate has not expired or been revoked.

Setting ``verify`` to False disables SSL certificate verification. This is typically used in development or testing environments when working with self-signed certificates or servers without valid certificates.


Trust environment
The ``trust_env`` parameter determines whether or not the library should use environment variables for configuration when making HTTP/HTTPS requests.
By default, ``trust_env`` parameter is set to True. This means the `httpx <https://www.python-httpx.org>`_ library will automatically trust and use environment variables for proxy configuration, SSL settings, and authentication.


Base URL
The ``base_url`` parameter allows you to override the default base url used to make it requests, namely ``https://graph.microsoft.com/``.
This can be useful if you want to use the MSGraphAsyncOperator to call other Microsoft REST API's like Sharepoint or PowerBI.
Default value is None.


.. raw:: html

<div align="center" style="padding-bottom:10px">
<img src="images/msgraph.png"
alt="Microsoft Graph API connection form">
</div>


.. spelling:word-list::
Entra
83 changes: 63 additions & 20 deletions providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from urllib.parse import quote, urljoin, urlparse

import httpx
from azure.identity import ClientSecretCredential
from azure.identity import CertificateCredential, ClientSecretCredential
from httpx import AsyncHTTPTransport, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
Expand All @@ -47,6 +47,7 @@
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from azure.identity._internal.client_credential_base import ClientCredentialBase
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
Expand Down Expand Up @@ -107,6 +108,7 @@ class KiotaRequestAdapterHook(BaseHook):
"""

DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
DEFAULT_SCOPE = "https://graph.microsoft.com/.default"
cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
conn_type: str = "msgraph"
conn_name_attr: str = "conn_id"
Expand All @@ -119,15 +121,18 @@ def __init__(
timeout: float | None = None,
proxies: dict | None = None,
host: str = NationalClouds.Global.value,
scopes: list[str] | None = None,
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
):
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.proxies = proxies
self.host = host
self.scopes = scopes or ["https://graph.microsoft.com/.default"]
if isinstance(scopes, str):
self.scopes = [scopes]
else:
self.scopes = scopes or [self.DEFAULT_SCOPE]
self._api_version = self.resolve_api_version_from_value(api_version)

@classmethod
Expand All @@ -140,20 +145,21 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
return {
"tenant_id": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()),
"api_version": StringField(
lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default="v1.0"
lazy_gettext("API Version"), widget=BS3TextFieldWidget(), default=APIVersion.v1.value
),
"authority": StringField(lazy_gettext("Authority"), widget=BS3TextFieldWidget()),
"certificate_path": StringField(lazy_gettext("Certificate path"), widget=BS3TextFieldWidget()),
"certificate_data": StringField(lazy_gettext("Certificate data"), widget=BS3TextFieldWidget()),
"scopes": StringField(
lazy_gettext("Scopes"),
widget=BS3TextFieldWidget(),
default="https://graph.microsoft.com/.default",
default=cls.DEFAULT_SCOPE,
),
"disable_instance_discovery": BooleanField(
lazy_gettext("Disable instance discovery"), default=False
),
"allowed_hosts": StringField(lazy_gettext("Allowed"), widget=BS3TextFieldWidget()),
"allowed_hosts": StringField(lazy_gettext("Allowed hosts"), widget=BS3TextFieldWidget()),
"proxies": StringField(lazy_gettext("Proxies"), widget=BS3TextAreaFieldWidget()),
"stream": BooleanField(lazy_gettext("Stream"), default=False),
"verify": BooleanField(lazy_gettext("Verify"), default=True),
"trust_env": BooleanField(lazy_gettext("Trust environment"), default=True),
"base_url": StringField(lazy_gettext("Base URL"), widget=BS3TextFieldWidget()),
Expand Down Expand Up @@ -241,18 +247,17 @@ def get_conn(self) -> RequestAdapter:
client_id = connection.login
client_secret = connection.password
config = connection.extra_dejson if connection.extra else {}
tenant_id = config.get("tenant_id") or config.get("tenantId")
api_version = self.get_api_version(config)
host = self.get_host(connection)
base_url = config.get("base_url", urljoin(host, api_version))
authority = config.get("authority")
proxies = self.proxies or config.get("proxies", {})
msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
httpx_proxies = self.to_httpx_proxies(proxies=proxies)
scopes = config.get("scopes", self.scopes)
if isinstance(scopes, str):
scopes = scopes.split(",")
verify = config.get("verify", True)
trust_env = config.get("trust_env", False)
disable_instance_discovery = config.get("disable_instance_discovery", False)
allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")

self.log.info(
Expand All @@ -262,7 +267,6 @@ def get_conn(self) -> RequestAdapter:
)
self.log.info("Host: %s", host)
self.log.info("Base URL: %s", base_url)
self.log.info("Tenant id: %s", tenant_id)
self.log.info("Client id: %s", client_id)
self.log.info("Client secret: %s", client_secret)
self.log.info("API version: %s", api_version)
Expand All @@ -271,19 +275,16 @@ def get_conn(self) -> RequestAdapter:
self.log.info("Timeout: %s", self.timeout)
self.log.info("Trust env: %s", trust_env)
self.log.info("Authority: %s", authority)
self.log.info("Disable instance discovery: %s", disable_instance_discovery)
self.log.info("Allowed hosts: %s", allowed_hosts)
self.log.info("Proxies: %s", proxies)
self.log.info("MSAL Proxies: %s", msal_proxies)
self.log.info("HTTPX Proxies: %s", httpx_proxies)
credentials = ClientSecretCredential(
tenant_id=tenant_id, # type: ignore
client_id=connection.login,
client_secret=connection.password,
credentials = self.get_credentials(
login=connection.login,
password=connection.password,
config=config,
authority=authority,
proxies=msal_proxies,
disable_instance_discovery=disable_instance_discovery,
connection_verify=verify,
verify=verify,
proxies=proxies,
)
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version, # type: ignore
Expand Down Expand Up @@ -313,6 +314,48 @@ def get_conn(self) -> RequestAdapter:
self._api_version = api_version
return request_adapter

def get_credentials(
self,
login: str | None,
password: str | None,
config,
authority: str | None,
verify: bool,
proxies: dict,
) -> ClientCredentialBase:
tenant_id = config.get("tenant_id") or config.get("tenantId")
certificate_path = config.get("certificate_path")
certificate_data = config.get("certificate_data")
disable_instance_discovery = config.get("disable_instance_discovery", False)
msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
self.log.info("Tenant id: %s", tenant_id)
self.log.info("Certificate path: %s", certificate_path)
self.log.info("Certificate data: %s", certificate_data is not None)
self.log.info("Authority: %s", authority)
self.log.info("Disable instance discovery: %s", disable_instance_discovery)
self.log.info("MSAL Proxies: %s", msal_proxies)
if certificate_path or certificate_data:
return CertificateCredential(
tenant_id=tenant_id, # type: ignore
client_id=login, # type: ignore
password=password,
certificate_path=certificate_path,
certificate_data=certificate_data.encode() if certificate_data else None,
authority=authority,
proxies=msal_proxies,
disable_instance_discovery=disable_instance_discovery,
connection_verify=verify,
)
return ClientSecretCredential(
tenant_id=tenant_id, # type: ignore
client_id=login, # type: ignore
client_secret=password, # type: ignore
authority=authority,
proxies=msal_proxies,
disable_instance_discovery=disable_instance_discovery,
connection_verify=verify,
)

def test_connection(self):
"""Test HTTP Connection."""
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class MSGraphAsyncOperator(BaseOperator):
:param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None).
When no timeout is specified or set to None then there is no HTTP timeout on each request.
:param proxies: A dict defining the HTTP proxies to be used (default is None).
:param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]).
:param api_version: The API version of the Microsoft Graph API to be used (default is v1).
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as `v1.0` or `beta`.
Expand Down Expand Up @@ -110,6 +111,7 @@ def __init__(
key: str = XCOM_RETURN_KEY,
timeout: float | None = None,
proxies: dict | None = None,
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
Expand All @@ -130,6 +132,7 @@ def __init__(
self.key = key
self.timeout = timeout
self.proxies = proxies
self.scopes = scopes
self.api_version = api_version
self.pagination_function = pagination_function or self.paginate
self.result_processor = result_processor
Expand All @@ -150,6 +153,7 @@ def execute(self, context: Context) -> None:
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
scopes=self.scopes,
api_version=self.api_version,
serializer=type(self.serializer),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MSGraphSensor(BaseSensorOperator):
:param method: The HTTP method being used to do the REST call (default is GET).
:param conn_id: The HTTP Connection ID to run the operator against (templated).
:param proxies: A dict defining the HTTP proxies to be used (default is None).
:param scopes: The scopes to be used (default is ["https://graph.microsoft.com/.default"]).
:param api_version: The API version of the Microsoft Graph API to be used (default is v1).
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as `v1.0` or `beta`.
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
data: dict[str, Any] | str | BytesIO | None = None,
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
proxies: dict | None = None,
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
Expand All @@ -101,6 +103,7 @@ def __init__(
self.data = data
self.conn_id = conn_id
self.proxies = proxies
self.scopes = scopes
self.api_version = api_version
self.event_processor = event_processor
self.result_processor = result_processor
Expand All @@ -120,6 +123,7 @@ def execute(self, context: Context):
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
scopes=self.scopes,
api_version=self.api_version,
serializer=type(self.serializer),
),
Expand Down
Loading

0 comments on commit ff1e3a6

Please sign in to comment.