Fix azure provider breaking change #1341

Merged (10 commits, Oct 23, 2023)
.circleci/config.yml (1 addition, 1 deletion)

@@ -160,7 +160,7 @@ jobs:
             sudo apt-get install libsasl2-dev
       - run:
           name: Install Dependencies
-          command: pip install -U -e .[all,docs,mypy]
+          command: pip install -U -e .[docs]
       - run:
           name: Run Sphinx
           command: |
@@ -40,8 +40,6 @@
 default_args = {
     "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
     "azure_data_factory_conn_id": "azure_data_factory_default",
-    "factory_name": DATAFACTORY_NAME,  # This can also be specified in the ADF connection.
-    "resource_group_name": RESOURCE_GROUP_NAME,  # This can also be specified in the ADF connection.
     "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
     "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
 }

@@ -88,7 +86,7 @@ def create_adf_storage_pipeline() -> None:
     df_resource = Factory(location=LOCATION)
     df = adf_client.factories.create_or_update(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, df_resource)
     while df.provisioning_state != "Succeeded":
-        df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+        df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)  # type: ignore[assignment]
         time.sleep(1)

     # Create an Azure Storage linked service

@@ -97,17 +95,17 @@ def create_adf_storage_pipeline() -> None:
     storage_string = SecureString(value=CONNECTION_STRING)

     ls_azure_storage = LinkedServiceResource(
-        properties=AzureStorageLinkedService(connection_string=storage_string)
+        properties=AzureStorageLinkedService(connection_string=storage_string)  # type: ignore[arg-type]
     )
     adf_client.linked_services.create_or_update(
         RESOURCE_GROUP_NAME, DATAFACTORY_NAME, STORAGE_LINKED_SERVICE_NAME, ls_azure_storage
     )

     # Create an Azure blob dataset (input)
-    ds_ls = LinkedServiceReference(reference_name=STORAGE_LINKED_SERVICE_NAME)
+    ds_ls = LinkedServiceReference(type="LinkedServiceReference", reference_name=STORAGE_LINKED_SERVICE_NAME)
Collaborator: Do we need to support/take care of backward compatibility for the previous versions of the provider?

Contributor Author: And we used to install azure-mgmt-datafactory >= 1.0.0 here.

Contributor Author: Also, as an example DAG, I think we should use the newer version. WDYT?
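A minimal sketch of what backward compatibility could look like for the example DAG, assuming the `type` discriminator only became a required constructor argument in newer azure-mgmt-datafactory majors (the helper name and the exact version cutoff are assumptions, not part of this PR):

```python
# Hypothetical compatibility shim; not part of this change.
from importlib.metadata import version

from azure.mgmt.datafactory.models import LinkedServiceReference


def make_linked_service_reference(name: str) -> LinkedServiceReference:
    """Build a LinkedServiceReference across azure-mgmt-datafactory majors."""
    major = int(version("azure-mgmt-datafactory").split(".")[0])
    if major >= 2:  # assumed cutoff; newer model classes require the discriminator
        return LinkedServiceReference(type="LinkedServiceReference", reference_name=name)
    # Older releases (>= 1.0.0, as noted above) accepted the name alone.
    return LinkedServiceReference(reference_name=name)
```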

     ds_azure_blob = DatasetResource(
         properties=AzureBlobDataset(
-            linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME
+            linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME  # type: ignore[arg-type]
         )
     )
     adf_client.datasets.create_or_update(

@@ -116,7 +114,7 @@ def create_adf_storage_pipeline() -> None:

     # Create an Azure blob dataset (output)
     ds_out_azure_blob = DatasetResource(
-        properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH)
+        properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH)  # type: ignore[arg-type]
     )
     adf_client.datasets.create_or_update(
         RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_OUTPUT_NAME, ds_out_azure_blob

@@ -125,8 +123,8 @@ def create_adf_storage_pipeline() -> None:
     # Create a copy activity
     blob_source = BlobSource()
     blob_sink = BlobSink()
-    ds_in_ref = DatasetReference(reference_name=DATASET_INPUT_NAME)
-    ds_out_ref = DatasetReference(reference_name=DATASET_OUTPUT_NAME)
+    ds_in_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_INPUT_NAME)
+    ds_out_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_OUTPUT_NAME)
Collaborator (on lines +126 to +127): Same question for backward compatibility.
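If backward compatibility is wanted here as well, the same version-guarded helper sketched above for LinkedServiceReference would apply to DatasetReference, since its constructor changed in the same way.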

     copy_activity = CopyActivity(
         name=ACTIVITY_NAME, inputs=[ds_in_ref], outputs=[ds_out_ref], source=blob_source, sink=blob_sink
     )

@@ -194,13 +192,17 @@ def delete_azure_data_factory_storage_pipeline() -> None:
     run_pipeline_wait = AzureDataFactoryRunPipelineOperatorAsync(
         task_id="run_pipeline_wait",
         pipeline_name=PIPELINE_NAME,
+        factory_name=DATAFACTORY_NAME,
+        resource_group_name=RESOURCE_GROUP_NAME,
     )
     # [END howto_operator_adf_run_pipeline_async]

     # [START howto_operator_adf_run_pipeline]
     run_pipeline_no_wait = AzureDataFactoryRunPipelineOperatorAsync(
         task_id="run_pipeline_no_wait",
         pipeline_name=PIPELINE_NAME,
+        factory_name=DATAFACTORY_NAME,
+        resource_group_name=RESOURCE_GROUP_NAME,
         wait_for_termination=False,
     )
     # [END howto_operator_adf_run_pipeline]

@@ -209,6 +211,8 @@ def delete_azure_data_factory_storage_pipeline() -> None:

     pipeline_run_sensor_async = AzureDataFactoryPipelineRunStatusSensorAsync(
         task_id="pipeline_run_sensor_async",
         run_id=cast(str, XComArg(run_pipeline_wait, key="run_id")),
+        factory_name=DATAFACTORY_NAME,
+        resource_group_name=RESOURCE_GROUP_NAME,
     )
     # [END howto_sensor_pipeline_run_sensor_async]
astronomer/providers/microsoft/azure/hooks/data_factory.py (6 additions, 6 deletions)

@@ -75,14 +75,14 @@ class AzureDataFactoryHookAsync(AzureDataFactoryHook):

     def __init__(self, azure_data_factory_conn_id: str):
         """Initialize the hook instance."""
-        self._async_conn: DataFactoryManagementClient = None
+        self._async_conn: DataFactoryManagementClient | None = None
         self.conn_id = azure_data_factory_conn_id
         super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

     async def get_async_conn(self) -> DataFactoryManagementClient:
         """Get async connection and connect to azure data factory."""
         if self._conn is not None:
-            return self._conn
+            return cast(DataFactoryManagementClient, self._conn)  # pragma: no cover

         conn = await sync_to_async(self.get_connection)(self.conn_id)
         extras = conn.extra_dejson

@@ -113,8 +113,8 @@ async def get_async_conn(self) -> DataFactoryManagementClient:
     async def get_pipeline_run(
         self,
         run_id: str,
-        resource_group_name: str | None = None,
-        factory_name: str | None = None,
+        resource_group_name: str,
+        factory_name: str,
         **config: Any,
     ) -> PipelineRun:
         """

@@ -132,7 +132,7 @@ async def get_pipeline_run(
Expand All @@ -132,7 +132,7 @@ async def get_pipeline_run(
raise AirflowException(e)

async def get_adf_pipeline_run_status(
self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None
self, run_id: str, resource_group_name: str, factory_name: str
) -> str:
"""
Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.
Expand All @@ -147,7 +147,7 @@ async def get_adf_pipeline_run_status(
                 factory_name=factory_name,
                 resource_group_name=resource_group_name,
             )
-            status: str = pipeline_run.status
+            status: str = cast(str, pipeline_run.status)
             return status
         except Exception as e:
             raise AirflowException(e)
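With `resource_group_name` and `factory_name` now required rather than optional, a caller-side usage sketch of the async hook (the connection id matches the example DAG's default; the resource names are placeholders, not values from this PR):

```python
# Usage sketch under the new required-argument signatures.
import asyncio

from astronomer.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHookAsync


async def check_run_status(run_id: str) -> str:
    hook = AzureDataFactoryHookAsync(azure_data_factory_conn_id="azure_data_factory_default")
    # These two arguments can no longer be omitted or default to None.
    return await hook.get_adf_pipeline_run_status(
        run_id=run_id,
        resource_group_name="my-resource-group",
        factory_name="my-data-factory",
    )


# asyncio.run(check_run_status("<run-id>"))  # needs a configured Airflow connection
```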
@@ -6,7 +6,6 @@
     AzureDataFactoryHook,
     AzureDataFactoryPipelineRunException,
     AzureDataFactoryPipelineRunStatus,
-    PipelineRunInfo,
 )
 from airflow.providers.microsoft.azure.operators.data_factory import (
     AzureDataFactoryRunPipelineOperator,

@@ -67,12 +66,11 @@ def execute(self, context: Context) -> None:
         context["ti"].xcom_push(key="run_id", value=run_id)
         end_time = time.time() + self.timeout

-        pipeline_run_info = PipelineRunInfo(
+        pipeline_run_status = hook.get_pipeline_run_status(
             run_id=run_id,
-            factory_name=self.factory_name,
             resource_group_name=self.resource_group_name,
+            factory_name=self.factory_name,
         )
-        pipeline_run_status = hook.get_pipeline_run_status(**pipeline_run_info)
         if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
             self.defer(
                 timeout=self.execution_timeout,
astronomer/providers/microsoft/azure/triggers/data_factory.py (5 additions, 5 deletions)

@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
+from typing import Any, AsyncIterator, Dict, List, Tuple

 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryPipelineRunStatus,

@@ -29,8 +29,8 @@ def __init__(
         run_id: str,
         azure_data_factory_conn_id: str,
         poke_interval: float,
-        resource_group_name: Optional[str] = None,
-        factory_name: Optional[str] = None,
+        resource_group_name: str,
+        factory_name: str,
     ):
         super().__init__()
         self.run_id = run_id

@@ -108,8 +108,8 @@ def __init__(
         run_id: str,
         azure_data_factory_conn_id: str,
         end_time: float,
-        resource_group_name: Optional[str] = None,
-        factory_name: Optional[str] = None,
+        resource_group_name: str,
+        factory_name: str,
         wait_for_termination: bool = True,
         check_interval: int = 60,
     ):
setup.cfg (1 addition)

@@ -114,6 +114,7 @@ mypy =
     types-Markdown
     types-PyMySQL
     types-PyYAML
+    snowflake-connector-python>=3.3.0  # Temporary solution for the issue that pip cannot find a proper connector version

 # All extras from above except 'mypy', 'docs' and 'tests'
 all =
tests/microsoft/azure/operators/test_data_factory.py (2 additions)

@@ -25,6 +25,8 @@ class TestAzureDataFactoryRunPipelineOperatorAsync:
         task_id="run_pipeline",
         pipeline_name="pipeline",
         parameters={"myParam": "value"},
+        factory_name="factory_name",
+        resource_group_name="resource_group",
     )

     @mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer")
tests/microsoft/azure/sensors/test_data_factory.py (7 additions, 1 deletion)

@@ -19,6 +19,8 @@ class TestAzureDataFactoryPipelineRunStatusSensorAsync:
     SENSOR = AzureDataFactoryPipelineRunStatusSensorAsync(
         task_id="pipeline_run_sensor_async",
         run_id=RUN_ID,
+        factory_name="factory_name",
+        resource_group_name="resource_group_name",
     )

     @mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.defer")

@@ -61,5 +63,9 @@ def test_poll_interval_deprecation_warning(self):
         # TODO: Remove once deprecated
         with pytest.warns(expected_warning=DeprecationWarning):
             AzureDataFactoryPipelineRunStatusSensorAsync(
-                task_id="pipeline_run_sensor_async", run_id=self.RUN_ID, poll_interval=5.0
+                task_id="pipeline_run_sensor_async",
+                run_id=self.RUN_ID,
+                poll_interval=5.0,
+                factory_name="factory_name",
+                resource_group_name="resource_group_name",
             )