diff --git a/providers/src/airflow/providers/google/cloud/operators/bigquery.py b/providers/src/airflow/providers/google/cloud/operators/bigquery.py index bc950ec9f3f0c..46a9d008a22c8 100644 --- a/providers/src/airflow/providers/google/cloud/operators/bigquery.py +++ b/providers/src/airflow/providers/google/cloud/operators/bigquery.py @@ -29,7 +29,7 @@ from google.api_core.exceptions import Conflict from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob, Row -from google.cloud.bigquery.table import RowIterator +from google.cloud.bigquery.table import RowIterator, Table, TableReference from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException @@ -1339,6 +1339,7 @@ def __init__( self.cluster_fields = cluster_fields self.table_resource = table_resource self.impersonation_chain = impersonation_chain + self._table: Table | None = None if exists_ok is not None: warnings.warn( "`exists_ok` parameter is deprecated, please use `if_exists`", @@ -1369,6 +1370,7 @@ def execute(self, context: Context) -> None: try: self.log.info("Creating table") + # Save table as attribute for further use by OpenLineage self._table = bq_hook.create_empty_table( project_id=self.project_id, dataset_id=self.dataset_id, @@ -1414,7 +1416,8 @@ def execute(self, context: Context) -> None: BigQueryTableLink.persist(**persist_kwargs) - def get_openlineage_facets_on_complete(self, task_instance): + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we will use table resource returned by create method.""" from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.google.cloud.openlineage.utils import ( BIGQUERY_NAMESPACE, @@ -1422,11 +1425,13 @@ def get_openlineage_facets_on_complete(self, task_instance): ) from airflow.providers.openlineage.extractors import OperatorLineage - table_info = self._table.to_api_repr()["tableReference"] - table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"])) + if not self._table: + self.log.debug("OpenLineage did not find `self._table` attribute.") + return OperatorLineage() + output_dataset = Dataset( namespace=BIGQUERY_NAMESPACE, - name=table_id, + name=f"{self._table.project}.{self._table.dataset_id}.{self._table.table_id}", facets=get_facets_from_bq_table(self._table), ) @@ -1649,6 +1654,7 @@ def __init__( self.encryption_configuration = encryption_configuration self.location = location self.impersonation_chain = impersonation_chain + self._table: Table | None = None def execute(self, context: Context) -> None: bq_hook = BigQueryHook( @@ -1657,15 +1663,16 @@ def execute(self, context: Context) -> None: impersonation_chain=self.impersonation_chain, ) if self.table_resource: + # Save table as attribute for further use by OpenLineage self._table = bq_hook.create_empty_table( table_resource=self.table_resource, ) BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"], - project_id=self._table.to_api_repr()["tableReference"]["projectId"], - table_id=self._table.to_api_repr()["tableReference"]["tableId"], + dataset_id=self._table.dataset_id, + project_id=self._table.project, + table_id=self._table.table_id, ) return @@ -1716,19 +1723,19 @@ def execute(self, context: Context) -> None: "encryptionConfiguration": self.encryption_configuration, } - self._table = bq_hook.create_empty_table( - table_resource=table_resource, - ) + # Save table as attribute for further use by OpenLineage + self._table = bq_hook.create_empty_table(table_resource=table_resource) BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=self._table.to_api_repr()["tableReference"]["datasetId"], - project_id=self._table.to_api_repr()["tableReference"]["projectId"], - table_id=self._table.to_api_repr()["tableReference"]["tableId"], + dataset_id=self._table.dataset_id, + project_id=self._table.project, + table_id=self._table.table_id, ) - def get_openlineage_facets_on_complete(self, task_instance): + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we will use table resource returned by create method.""" from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.google.cloud.openlineage.utils import ( BIGQUERY_NAMESPACE, @@ -1736,11 +1743,9 @@ def get_openlineage_facets_on_complete(self, task_instance): ) from airflow.providers.openlineage.extractors import OperatorLineage - table_info = self._table.to_api_repr()["tableReference"] - table_id = ".".join((table_info["projectId"], table_info["datasetId"], table_info["tableId"])) output_dataset = Dataset( namespace=BIGQUERY_NAMESPACE, - name=table_id, + name=f"{self._table.project}.{self._table.dataset_id}.{self._table.table_id}", facets=get_facets_from_bq_table(self._table), ) @@ -2133,6 +2138,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.table_resource = table_resource self.impersonation_chain = impersonation_chain + self._table: dict | None = None super().__init__(**kwargs) def execute(self, context: Context): @@ -2141,7 +2147,8 @@ def execute(self, context: Context): impersonation_chain=self.impersonation_chain, ) - table = bq_hook.update_table( + # Save table as attribute for further use by OpenLineage + self._table = bq_hook.update_table( table_resource=self.table_resource, fields=self.fields, dataset_id=self.dataset_id, @@ -2152,12 +2159,30 @@ def execute(self, context: Context): BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=table["tableReference"]["datasetId"], - project_id=table["tableReference"]["projectId"], - table_id=table["tableReference"]["tableId"], + dataset_id=self._table["tableReference"]["datasetId"], + project_id=self._table["tableReference"]["projectId"], + table_id=self._table["tableReference"]["tableId"], + ) + + return self._table + + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we will use table resource returned by update method.""" + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.google.cloud.openlineage.utils import ( + BIGQUERY_NAMESPACE, + get_facets_from_bq_table, ) + from airflow.providers.openlineage.extractors import OperatorLineage - return table + table = Table.from_api_repr(self._table) + output_dataset = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=f"{table.project}.{table.dataset_id}.{table.table_id}", + facets=get_facets_from_bq_table(table), + ) + + return OperatorLineage(outputs=[output_dataset]) class BigQueryUpdateDatasetOperator(GoogleCloudBaseOperator): @@ -2291,15 +2316,47 @@ def __init__( self.ignore_if_missing = ignore_if_missing self.location = location self.impersonation_chain = impersonation_chain + self.hook: BigQueryHook | None = None def execute(self, context: Context) -> None: self.log.info("Deleting: %s", self.deletion_dataset_table) - hook = BigQueryHook( + # Save hook as attribute for further use by OpenLineage + self.hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, location=self.location, impersonation_chain=self.impersonation_chain, ) - hook.delete_table(table_id=self.deletion_dataset_table, not_found_ok=self.ignore_if_missing) + self.hook.delete_table(table_id=self.deletion_dataset_table, not_found_ok=self.ignore_if_missing) + + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we need default project_id from hook.""" + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, + ) + from airflow.providers.google.cloud.openlineage.utils import BIGQUERY_NAMESPACE + from airflow.providers.openlineage.extractors import OperatorLineage + + bq_table_id = str( + TableReference.from_string(self.deletion_dataset_table, default_project=self.hook.project_id) + ) + ds = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=bq_table_id, + facets={ + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( + namespace=BIGQUERY_NAMESPACE, + name=bq_table_id, + ), + ) + }, + ) + + return OperatorLineage(inputs=[ds]) class BigQueryUpsertTableOperator(GoogleCloudBaseOperator): @@ -2358,6 +2415,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.location = location self.impersonation_chain = impersonation_chain + self._table: dict | None = None def execute(self, context: Context) -> None: self.log.info("Upserting Dataset: %s with table_resource: %s", self.dataset_id, self.table_resource) @@ -2366,7 +2424,8 @@ def execute(self, context: Context) -> None: location=self.location, impersonation_chain=self.impersonation_chain, ) - table = hook.run_table_upsert( + # Save table as attribute for further use by OpenLineage + self._table = hook.run_table_upsert( dataset_id=self.dataset_id, table_resource=self.table_resource, project_id=self.project_id, @@ -2374,11 +2433,29 @@ def execute(self, context: Context) -> None: BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=table["tableReference"]["datasetId"], - project_id=table["tableReference"]["projectId"], - table_id=table["tableReference"]["tableId"], + dataset_id=self._table["tableReference"]["datasetId"], + project_id=self._table["tableReference"]["projectId"], + table_id=self._table["tableReference"]["tableId"], ) + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we will use table resource returned by upsert method.""" + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.google.cloud.openlineage.utils import ( + BIGQUERY_NAMESPACE, + get_facets_from_bq_table, + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + table = Table.from_api_repr(self._table) + output_dataset = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=f"{table.project}.{table.dataset_id}.{table.table_id}", + facets=get_facets_from_bq_table(table), + ) + + return OperatorLineage(outputs=[output_dataset]) + class BigQueryUpdateTableSchemaOperator(GoogleCloudBaseOperator): """ @@ -2466,6 +2543,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.location = location + self._table: dict | None = None super().__init__(**kwargs) def execute(self, context: Context): @@ -2473,7 +2551,8 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, location=self.location ) - table = bq_hook.update_table_schema( + # Save table as attribute for further use by OpenLineage + self._table = bq_hook.update_table_schema( schema_fields_updates=self.schema_fields_updates, include_policy_tags=self.include_policy_tags, dataset_id=self.dataset_id, @@ -2484,11 +2563,29 @@ def execute(self, context: Context): BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=table["tableReference"]["datasetId"], - project_id=table["tableReference"]["projectId"], - table_id=table["tableReference"]["tableId"], + dataset_id=self._table["tableReference"]["datasetId"], + project_id=self._table["tableReference"]["projectId"], + table_id=self._table["tableReference"]["tableId"], ) - return table + return self._table + + def get_openlineage_facets_on_complete(self, _): + """Implement _on_complete as we will use table resource returned by update method.""" + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.google.cloud.openlineage.utils import ( + BIGQUERY_NAMESPACE, + get_facets_from_bq_table, + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + table = Table.from_api_repr(self._table) + output_dataset = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=f"{table.project}.{table.dataset_id}.{table.table_id}", + facets=get_facets_from_bq_table(table), + ) + + return OperatorLineage(outputs=[output_dataset]) class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryInsertJobOperatorOpenLineageMixin): diff --git a/providers/tests/google/cloud/operators/test_bigquery.py b/providers/tests/google/cloud/operators/test_bigquery.py index 2f7d5ef57df05..1fc8ab93cdcf9 100644 --- a/providers/tests/google/cloud/operators/test_bigquery.py +++ b/providers/tests/google/cloud/operators/test_bigquery.py @@ -41,6 +41,9 @@ ExternalQueryRunFacet, Identifier, InputDataset, + LifecycleStateChange, + LifecycleStateChangeDatasetFacet, + PreviousIdentifier, SchemaDatasetFacet, SchemaDatasetFacetFields, SQLJobFacet, @@ -570,6 +573,49 @@ def test_execute(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_get_openlineage_facets_on_complete(self, mock_hook): + table_resource = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING", "description": "field1 description"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + mock_hook.return_value.update_table.return_value = table_resource + operator = BigQueryUpdateTableOperator( + table_resource={}, + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + project_id=TEST_GCP_PROJECT_ID, + ) + + operator.execute(context=MagicMock()) + result = operator.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert not result.inputs + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == BIGQUERY_NAMESPACE + assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + assert result.outputs[0].facets == { + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet(description="Table description."), + } + class TestBigQueryUpdateTableSchemaOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") @@ -606,6 +652,59 @@ def test_execute(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_get_openlineage_facets_on_complete(self, mock_hook): + table_resource = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING", "description": "field1 description"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + mock_hook.return_value.update_table_schema.return_value = table_resource + schema_field_updates = [ + { + "name": "emp_name", + "description": "Name of employee", + } + ] + + operator = BigQueryUpdateTableSchemaOperator( + schema_fields_updates=schema_field_updates, + include_policy_tags=False, + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + impersonation_chain=["service-account@myproject.iam.gserviceaccount.com"], + ) + operator.execute(context=MagicMock()) + + result = operator.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert not result.inputs + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == BIGQUERY_NAMESPACE + assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + assert result.outputs[0].facets == { + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet(description="Table description."), + } + class TestBigQueryUpdateDatasetOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") @@ -889,6 +988,33 @@ def test_execute(self, mock_hook): table_id=deletion_dataset_table, not_found_ok=ignore_if_missing ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_get_openlineage_facets_on_complete(self, mock_hook): + mock_hook.return_value.project_id = "default_project_id" + operator = BigQueryDeleteTableOperator( + task_id=TASK_ID, + deletion_dataset_table=f"{TEST_DATASET}.{TEST_TABLE_ID}", + ignore_if_missing=True, + ) + + operator.execute(None) + result = operator.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert not result.outputs + assert len(result.inputs) == 1 + assert result.inputs[0].namespace == BIGQUERY_NAMESPACE + assert result.inputs[0].name == f"default_project_id.{TEST_DATASET}.{TEST_TABLE_ID}" + assert result.inputs[0].facets == { + "lifecycleStateChange": LifecycleStateChangeDatasetFacet( + lifecycleStateChange=LifecycleStateChange.DROP.value, + previousIdentifier=PreviousIdentifier( + namespace=BIGQUERY_NAMESPACE, + name=f"default_project_id.{TEST_DATASET}.{TEST_TABLE_ID}", + ), + ) + } + class TestBigQueryGetDatasetTablesOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") @@ -941,6 +1067,48 @@ def test_execute(self, mock_hook): dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_resource=TEST_TABLE_RESOURCES ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_get_openlineage_facets_on_complete(self, mock_hook): + table_resource = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING", "description": "field1 description"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + mock_hook.return_value.run_table_upsert.return_value = table_resource + operator = BigQueryUpsertTableOperator( + task_id=TASK_ID, + dataset_id=TEST_DATASET, + table_resource=TEST_TABLE_RESOURCES, + project_id=TEST_GCP_PROJECT_ID, + ) + operator.execute(context=MagicMock()) + + result = operator.get_openlineage_facets_on_complete(None) + assert not result.run_facets + assert not result.job_facets + assert not result.inputs + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == BIGQUERY_NAMESPACE + assert result.outputs[0].name == f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + assert result.outputs[0].facets == { + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING", description="field1 description"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet(description="Table description."), + } + class TestBigQueryInsertJobOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")