Skip to content

Commit

Permalink
Fixes compat issue HTTPX proxy configuration in KiotaRequestAdapterHo…
Browse files Browse the repository at this point in the history
…ok and fixed retry in MSGraphSensor (apache#45746)


---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Jan 19, 2025
1 parent e2da4c7 commit ee785a8
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import httpx
from azure.identity import ClientSecretCredential
from httpx import Timeout
from httpx import AsyncHTTPTransport, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
Expand Down Expand Up @@ -208,9 +208,9 @@ def format_no_proxy_url(url: str) -> str:
def to_httpx_proxies(cls, proxies: dict) -> dict:
proxies = proxies.copy()
if proxies.get("http"):
proxies["http://"] = proxies.pop("http")
proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
if proxies.get("https"):
proxies["https://"] = proxies.pop("https")
proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
if proxies.get("no"):
for url in proxies.pop("no", "").split(","):
proxies[cls.format_no_proxy_url(url.strip())] = None
Expand Down Expand Up @@ -288,7 +288,7 @@ def get_conn(self) -> RequestAdapter:
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version, # type: ignore
client=httpx.AsyncClient(
proxy=httpx_proxies, # type: ignore
mounts=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def execute(self, context: Context):
def retry_execute(
self,
context: Context,
**kwargs,
) -> Any:
self.execute(context=context)

Expand Down
2 changes: 1 addition & 1 deletion providers/tests/microsoft/azure/resources/status.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}
[{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "InProgress"},{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}]
27 changes: 21 additions & 6 deletions providers/tests/microsoft/azure/sensors/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import json
from datetime import datetime

import pytest

Expand All @@ -31,7 +32,7 @@
class TestMSGraphSensor(Base):
def test_execute(self):
status = load_json("resources", "status.json")
response = mock_json_response(200, status)
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
sensor = MSGraphSensor(
Expand All @@ -40,6 +41,7 @@ def test_execute(self):
url="myorg/admin/workspaces/scanStatus/{scanId}",
path_parameters={"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
result_processor=lambda context, result: result["id"],
retry_delay=5,
timeout=350.0,
)

Expand All @@ -48,16 +50,22 @@ def test_execute(self):
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
assert isinstance(results, str)
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
assert len(events) == 1
assert len(events) == 3
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "success"
assert events[0].payload["type"] == "builtins.dict"
assert events[0].payload["response"] == json.dumps(status)
assert events[0].payload["response"] == json.dumps(status[0])
assert isinstance(events[1], TriggerEvent)
assert isinstance(events[1].payload, datetime)
assert isinstance(events[2], TriggerEvent)
assert events[2].payload["status"] == "success"
assert events[2].payload["type"] == "builtins.dict"
assert events[2].payload["response"] == json.dumps(status[1])

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter(self):
status = load_json("resources", "status.json")
response = mock_json_response(200, status)
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
sensor = MSGraphSensor(
Expand All @@ -66,6 +74,7 @@ def test_execute_with_lambda_parameter(self):
url="myorg/admin/workspaces/scanStatus/{scanId}",
path_parameters=lambda context, jinja_env: {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
result_processor=lambda context, result: result["id"],
retry_delay=5,
timeout=350.0,
)

Expand All @@ -74,11 +83,17 @@ def test_execute_with_lambda_parameter(self):
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
assert isinstance(results, str)
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
assert len(events) == 1
assert len(events) == 3
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "success"
assert events[0].payload["type"] == "builtins.dict"
assert events[0].payload["response"] == json.dumps(status)
assert events[0].payload["response"] == json.dumps(status[0])
assert isinstance(events[1], TriggerEvent)
assert isinstance(events[1].payload, datetime)
assert isinstance(events[2], TriggerEvent)
assert events[2].payload["status"] == "success"
assert events[2].payload["type"] == "builtins.dict"
assert events[2].payload["response"] == json.dumps(status[1])

def test_template_fields(self):
sensor = MSGraphSensor(
Expand Down
6 changes: 4 additions & 2 deletions providers/tests/microsoft/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ def xcom_pull(
run_id: str | None = None,
) -> Any:
if map_indexes:
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
return values.get(
f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default
)
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)

def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
Expand Down

0 comments on commit ee785a8

Please sign in to comment.