From ce16963e9d69849309aa0a7cf978ed85ab741439 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Fri, 17 Nov 2023 15:52:40 +0100 Subject: [PATCH] Add OpenLineage support to BigQueryToGCSOperator (#35660) --- .../google/cloud/transfers/bigquery_to_gcs.py | 75 ++++- .../google/cloud/utils/openlineage.py | 80 +++++ docs/spelling_wordlist.txt | 1 + .../cloud/transfers/test_bigquery_to_gcs.py | 274 ++++++++++++++++++ .../google/cloud/utils/test_openlineage.py | 142 +++++++++ 5 files changed, 571 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/google/cloud/utils/openlineage.py create mode 100644 tests/providers/google/cloud/utils/test_openlineage.py diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index a01c564cc3b7b..58456b10f9816 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -29,6 +29,7 @@ from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger +from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: from google.api_core.retry import Retry @@ -139,6 +140,8 @@ def __init__( self.hook: BigQueryHook | None = None self.deferrable = deferrable + self._job_id: str = "" + @staticmethod def _handle_job_error(job: BigQueryJob | UnknownJob) -> None: if job.error_result: @@ -240,6 +243,7 @@ def execute(self, context: Context): f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) + self._job_id = job.job_id conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"] dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"] BigQueryTableLink.persist( @@ -255,7 +259,7 @@ def execute(self, context: Context): timeout=self.execution_timeout, trigger=BigQueryInsertJobTrigger( conn_id=self.gcp_conn_id, - job_id=job_id, + job_id=self._job_id, project_id=self.project_id or self.hook.project_id, ), method_name="execute_complete", @@ -276,3 +280,72 @@ def execute_complete(self, context: Context, event: dict[str, Any]): self.task_id, event["message"], ) + + def get_openlineage_facets_on_complete(self, task_instance): + """Implementing on_complete as we will include final BQ job id.""" + from pathlib import Path + + from openlineage.client.facet import ( + ExternalQueryRunFacet, + SymlinksDatasetFacet, + SymlinksDatasetFacetIdentifiers, + ) + from openlineage.client.run import Dataset + + from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url + from airflow.providers.google.cloud.utils.openlineage import ( + get_facets_from_bq_table, + get_identity_column_lineage_facet, + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + table_object = self.hook.get_client(self.hook.project_id).get_table(self.source_project_dataset_table) + + input_dataset = Dataset( + namespace="bigquery", + name=str(table_object.reference), + facets=get_facets_from_bq_table(table_object), + ) + + output_dataset_facets = { + "schema": input_dataset.facets["schema"], + "columnLineage": get_identity_column_lineage_facet( + field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset] + ), + } + output_datasets = [] + for uri in sorted(self.destination_cloud_storage_uris): + bucket, blob = _parse_gcs_url(uri) + 
additional_facets = {}
+
+            if "*" in blob:
+                # If a wildcard ("*") is used in the GCS path, we want the dataset name to be the
+                # directory name, but we create a symlink to the full object path with the wildcard.
+                additional_facets = {
+                    "symlink": SymlinksDatasetFacet(
+                        identifiers=[
+                            SymlinksDatasetFacetIdentifiers(
+                                namespace=f"gs://{bucket}", name=blob, type="file"
+                            )
+                        ]
+                    ),
+                }
+                blob = Path(blob).parent.as_posix()
+                if blob == ".":
+                    # The blob path has no leading slash, but the root dataset name must be "/".
+                    blob = "/"
+
+            dataset = Dataset(
+                namespace=f"gs://{bucket}",
+                name=blob,
+                facets=merge_dicts(output_dataset_facets, additional_facets),
+            )
+            output_datasets.append(dataset)
+
+        run_facets = {}
+        if self._job_id:
+            run_facets = {
+                "externalQuery": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery"),
+            }
+
+        return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets)
diff --git a/airflow/providers/google/cloud/utils/openlineage.py b/airflow/providers/google/cloud/utils/openlineage.py
new file mode 100644
index 0000000000000..3e96fffe5af33
--- /dev/null
+++ b/airflow/providers/google/cloud/utils/openlineage.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains code related to OpenLineage and lineage extraction."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from openlineage.client.facet import (
+    ColumnLineageDatasetFacet,
+    ColumnLineageDatasetFacetFieldsAdditional,
+    ColumnLineageDatasetFacetFieldsAdditionalInputFields,
+    DocumentationDatasetFacet,
+    SchemaDatasetFacet,
+    SchemaField,
+)
+
+if TYPE_CHECKING:
+    from google.cloud.bigquery.table import Table
+    from openlineage.client.run import Dataset
+
+
+def get_facets_from_bq_table(table: Table) -> dict[Any, Any]:
+    """Get facets from BigQuery table object."""
+    facets = {
+        "schema": SchemaDatasetFacet(
+            fields=[
+                SchemaField(name=field.name, type=field.field_type, description=field.description)
+                for field in table.schema
+            ]
+        ),
+        "documentation": DocumentationDatasetFacet(description=table.description or ""),
+    }
+
+    return facets
+
+
+def get_identity_column_lineage_facet(
+    field_names: list[str],
+    input_datasets: list[Dataset],
+) -> ColumnLineageDatasetFacet:
+    """
+    Get column lineage facet.
+
+    Simple lineage will be created, where each source column corresponds to a single destination column
+    in each input dataset, with no transformations applied.
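+
+    For example, ``field_names=["field1", "field2"]`` with a single input dataset yields a facet
+    where output ``field1`` maps to input ``field1`` and output ``field2`` maps to input ``field2``,
+    each with transformation type ``IDENTITY``.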
+    """
+    if field_names and not input_datasets:
+        raise ValueError("When providing `field_names` you must provide at least one dataset in `input_datasets`.")
+
+    column_lineage_facet = ColumnLineageDatasetFacet(
+        fields={
+            field: ColumnLineageDatasetFacetFieldsAdditional(
+                inputFields=[
+                    ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                        namespace=dataset.namespace, name=dataset.name, field=field
+                    )
+                    for dataset in input_datasets
+                ],
+                transformationType="IDENTITY",
+                transformationDescription="identical",
+            )
+            for field in field_names
+        }
+    )
+    return column_lineage_facet
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index c56f81ebaf75a..65bbc8cb6fce2 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -272,6 +272,7 @@ codepoints
 Colour
 colour
 colours
+ColumnLineageDatasetFacet
 CommandType
 comparator
 compat
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
index b7bf8bef62d8c..5dd32892abb43 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -22,6 +22,19 @@
 import pytest
 from google.cloud.bigquery.retry import DEFAULT_RETRY
+from google.cloud.bigquery.table import Table
+from openlineage.client.facet import (
+    ColumnLineageDatasetFacet,
+    ColumnLineageDatasetFacetFieldsAdditional,
+    ColumnLineageDatasetFacetFieldsAdditionalInputFields,
+    DocumentationDatasetFacet,
+    ExternalQueryRunFacet,
+    SchemaDatasetFacet,
+    SchemaField,
+    SymlinksDatasetFacet,
+    SymlinksDatasetFacetIdentifiers,
+)
+from openlineage.client.run import Dataset

 from airflow.exceptions import TaskDeferred
 from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator
@@ -32,6 +45,25 @@
 TEST_TABLE_ID = "test-table-id"
 PROJECT_ID = "test-project-id"
 JOB_PROJECT_ID = "job-project-id"
+TEST_BUCKET = "test-bucket"
+TEST_FOLDER = "test-folder"
+TEST_OBJECT_NO_WILDCARD = "file.extension"
+TEST_OBJECT_WILDCARD = "file_*.extension"
+TEST_TABLE_API_REPR = {
+    "tableReference": {"projectId": PROJECT_ID, "datasetId": TEST_DATASET, "tableId": TEST_TABLE_ID},
+    "description": "Table description.",
+    "schema": {
+        "fields": [
+            {"name": "field1", "type": "STRING", "description": "field1 description"},
+            {"name": "field2", "type": "INTEGER"},
+        ]
+    },
+}
+TEST_TABLE: Table = Table.from_api_repr(TEST_TABLE_API_REPR)
+TEST_EMPTY_TABLE_API_REPR = {
+    "tableReference": {"projectId": PROJECT_ID, "datasetId": TEST_DATASET, "tableId": TEST_TABLE_ID}
+}
+TEST_EMPTY_TABLE: Table = Table.from_api_repr(TEST_EMPTY_TABLE_API_REPR)


 class TestBigQueryToGCSOperator:
@@ -154,3 +186,245 @@ def test_execute_deferrable_mode(self, mock_hook):
             retry=DEFAULT_RETRY,
             nowait=True,
         )
+
+    @pytest.mark.parametrize(
+        ("gcs_uri", "expected_dataset_name"),
+        (
+            (
+                f"gs://{TEST_BUCKET}/{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}",
+                f"{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}",
+            ),
+            (f"gs://{TEST_BUCKET}/{TEST_OBJECT_NO_WILDCARD}", TEST_OBJECT_NO_WILDCARD),
+            (f"gs://{TEST_BUCKET}/{TEST_FOLDER}/{TEST_OBJECT_WILDCARD}", TEST_FOLDER),
+            (f"gs://{TEST_BUCKET}/{TEST_OBJECT_WILDCARD}", "/"),
+            (f"gs://{TEST_BUCKET}/{TEST_FOLDER}/*", TEST_FOLDER),
+        ),
+    )
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_gcs_dataset_name(
+        self, mock_hook, gcs_uri, expected_dataset_name
+    ):
+        operator = BigQueryToGCSOperator(
+
project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}", + destination_cloud_storage_uris=[gcs_uri], + ) + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.outputs) == 1 + assert lineage.outputs[0].name == expected_dataset_name + + @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") + def test_get_openlineage_facets_on_complete_gcs_multiple_uris(self, mock_hook): + operator = BigQueryToGCSOperator( + project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}", + destination_cloud_storage_uris=[ + f"gs://{TEST_BUCKET}1/{TEST_FOLDER}1/{TEST_OBJECT_NO_WILDCARD}", + f"gs://{TEST_BUCKET}2/{TEST_FOLDER}2/{TEST_OBJECT_WILDCARD}", + f"gs://{TEST_BUCKET}3/{TEST_OBJECT_NO_WILDCARD}", + f"gs://{TEST_BUCKET}4/{TEST_OBJECT_WILDCARD}", + ], + ) + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.outputs) == 4 + assert lineage.outputs[0].name == f"{TEST_FOLDER}1/{TEST_OBJECT_NO_WILDCARD}" + assert lineage.outputs[1].name == f"{TEST_FOLDER}2" + assert lineage.outputs[2].name == TEST_OBJECT_NO_WILDCARD + assert lineage.outputs[3].name == "/" + + @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") + def test_get_openlineage_facets_on_complete_bq_dataset(self, mock_hook): + source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + + expected_input_dataset_facets = { + "schema": SchemaDatasetFacet( + fields=[ + SchemaField(name="field1", type="STRING", description="field1 description"), + SchemaField(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet(description="Table description."), + } + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE + + operator = BigQueryToGCSOperator( + project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=source_project_dataset_table, + destination_cloud_storage_uris=["gs://bucket/file"], + ) + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 1 + assert lineage.inputs[0] == Dataset( + namespace="bigquery", + name=source_project_dataset_table, + facets=expected_input_dataset_facets, + ) + + @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") + def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_hook): + source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + + expected_input_dataset_facets = { + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), + } + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE + + operator = BigQueryToGCSOperator( + project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=source_project_dataset_table, + 
destination_cloud_storage_uris=["gs://bucket/file"], + ) + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 1 + assert lineage.inputs[0] == Dataset( + namespace="bigquery", + name=source_project_dataset_table, + facets=expected_input_dataset_facets, + ) + + @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") + def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mock_hook): + source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + destination_cloud_storage_uris = [f"gs://{TEST_BUCKET}/{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}"] + real_job_id = "123456_hash" + bq_namespace = "bigquery" + + expected_input_facets = { + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), + } + + expected_output_facets = { + "schema": SchemaDatasetFacet(fields=[]), + "columnLineage": ColumnLineageDatasetFacet(fields={}), + } + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE + + operator = BigQueryToGCSOperator( + project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=source_project_dataset_table, + destination_cloud_storage_uris=destination_cloud_storage_uris, + ) + + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.inputs[0] == Dataset( + namespace=bq_namespace, name=source_project_dataset_table, facets=expected_input_facets + ) + assert lineage.outputs[0] == Dataset( + namespace=f"gs://{TEST_BUCKET}", + name=f"{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}", + facets=expected_output_facets, + ) + assert lineage.run_facets == { + "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) + } + assert lineage.job_facets == {} + + @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") + def test_get_openlineage_facets_on_complete_gcs_wildcard_full_table(self, mock_hook): + source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" + destination_cloud_storage_uris = [f"gs://{TEST_BUCKET}/{TEST_FOLDER}/{TEST_OBJECT_WILDCARD}"] + real_job_id = "123456_hash" + bq_namespace = "bigquery" + + schema_facet = SchemaDatasetFacet( + fields=[ + SchemaField(name="field1", type="STRING", description="field1 description"), + SchemaField(name="field2", type="INTEGER"), + ] + ) + expected_input_facets = { + "schema": schema_facet, + "documentation": DocumentationDatasetFacet(description="Table description."), + } + + expected_output_facets = { + "schema": schema_facet, + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "field1": ColumnLineageDatasetFacetFieldsAdditional( + inputFields=[ + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace=bq_namespace, name=source_project_dataset_table, field="field1" + ) + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + "field2": ColumnLineageDatasetFacetFieldsAdditional( + inputFields=[ + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace=bq_namespace, name=source_project_dataset_table, field="field2" + ) + ], + 
transformationType="IDENTITY", + transformationDescription="identical", + ), + } + ), + "symlink": SymlinksDatasetFacet( + identifiers=[ + SymlinksDatasetFacetIdentifiers( + namespace=f"gs://{TEST_BUCKET}", + name=f"{TEST_FOLDER}/{TEST_OBJECT_WILDCARD}", + type="file", + ) + ] + ), + } + + mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) + mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) + mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_TABLE + + operator = BigQueryToGCSOperator( + project_id=JOB_PROJECT_ID, + task_id=TASK_ID, + source_project_dataset_table=source_project_dataset_table, + destination_cloud_storage_uris=destination_cloud_storage_uris, + ) + + operator.execute(context=mock.MagicMock()) + + lineage = operator.get_openlineage_facets_on_complete(None) + assert len(lineage.inputs) == 1 + assert len(lineage.outputs) == 1 + assert lineage.inputs[0] == Dataset( + namespace=bq_namespace, name=source_project_dataset_table, facets=expected_input_facets + ) + assert lineage.outputs[0] == Dataset( + namespace=f"gs://{TEST_BUCKET}", name=TEST_FOLDER, facets=expected_output_facets + ) + assert lineage.run_facets == { + "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) + } + assert lineage.job_facets == {} diff --git a/tests/providers/google/cloud/utils/test_openlineage.py b/tests/providers/google/cloud/utils/test_openlineage.py new file mode 100644 index 0000000000000..608007fa4ba0d --- /dev/null +++ b/tests/providers/google/cloud/utils/test_openlineage.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from google.cloud.bigquery.table import Table +from openlineage.client.facet import ( + ColumnLineageDatasetFacet, + ColumnLineageDatasetFacetFieldsAdditional, + ColumnLineageDatasetFacetFieldsAdditionalInputFields, + DocumentationDatasetFacet, + SchemaDatasetFacet, + SchemaField, +) +from openlineage.client.run import Dataset + +from airflow.providers.google.cloud.utils import openlineage + +TEST_DATASET = "test-dataset" +TEST_TABLE_ID = "test-table-id" +TEST_PROJECT_ID = "test-project-id" +TEST_TABLE_API_REPR = { + "tableReference": {"projectId": TEST_PROJECT_ID, "datasetId": TEST_DATASET, "tableId": TEST_TABLE_ID}, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING", "description": "field1 description"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, +} +TEST_TABLE: Table = Table.from_api_repr(TEST_TABLE_API_REPR) +TEST_EMPTY_TABLE_API_REPR = { + "tableReference": {"projectId": TEST_PROJECT_ID, "datasetId": TEST_DATASET, "tableId": TEST_TABLE_ID} +} +TEST_EMPTY_TABLE: Table = Table.from_api_repr(TEST_EMPTY_TABLE_API_REPR) + + +def test_get_facets_from_bq_table(): + expected_facets = { + "schema": SchemaDatasetFacet( + fields=[ + SchemaField(name="field1", type="STRING", description="field1 description"), + SchemaField(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet(description="Table description."), + } + result = openlineage.get_facets_from_bq_table(TEST_TABLE) + assert result == expected_facets + + +def test_get_facets_from_empty_bq_table(): + expected_facets = { + "schema": SchemaDatasetFacet(fields=[]), + "documentation": DocumentationDatasetFacet(description=""), + } + result = openlineage.get_facets_from_bq_table(TEST_EMPTY_TABLE) + assert result == expected_facets + + +def test_get_identity_column_lineage_facet_multiple_input_datasets(): + field_names = ["field1", "field2"] + input_datasets = [ + Dataset(namespace="gs://first_bucket", name="dir1"), + Dataset(namespace="gs://second_bucket", name="dir2"), + ] + expected_facet = ColumnLineageDatasetFacet( + fields={ + "field1": ColumnLineageDatasetFacetFieldsAdditional( + inputFields=[ + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace="gs://first_bucket", + name="dir1", + field="field1", + ), + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace="gs://second_bucket", + name="dir2", + field="field1", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + "field2": ColumnLineageDatasetFacetFieldsAdditional( + inputFields=[ + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace="gs://first_bucket", + name="dir1", + field="field2", + ), + ColumnLineageDatasetFacetFieldsAdditionalInputFields( + namespace="gs://second_bucket", + name="dir2", + field="field2", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + } + ) + result = openlineage.get_identity_column_lineage_facet( + field_names=field_names, input_datasets=input_datasets + ) + assert result == expected_facet + + +def test_get_identity_column_lineage_facet_no_field_names(): + field_names = [] + input_datasets = [ + Dataset(namespace="gs://first_bucket", name="dir1"), + Dataset(namespace="gs://second_bucket", name="dir2"), + ] + expected_facet = ColumnLineageDatasetFacet(fields={}) + result = openlineage.get_identity_column_lineage_facet( + field_names=field_names, input_datasets=input_datasets + ) + assert result == 
expected_facet + + +def test_get_identity_column_lineage_facet_no_input_datasets(): + field_names = ["field1", "field2"] + input_datasets = [] + + with pytest.raises(ValueError): + openlineage.get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)
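
For illustration, a minimal sketch (separate from the patch above) of how the two new helpers might be used directly. The project, dataset, and table names are hypothetical placeholders, and a configured google-cloud-bigquery client with credentials in the environment is assumed:

    # Sketch only: "my-project", "my_dataset", and "my_table" are hypothetical.
    from google.cloud import bigquery
    from openlineage.client.run import Dataset

    from airflow.providers.google.cloud.utils.openlineage import (
        get_facets_from_bq_table,
        get_identity_column_lineage_facet,
    )

    client = bigquery.Client(project="my-project")
    table = client.get_table("my-project.my_dataset.my_table")

    # Schema and documentation facets, as built by get_facets_from_bq_table().
    input_dataset = Dataset(
        namespace="bigquery",
        name=str(table.reference),
        facets=get_facets_from_bq_table(table),
    )

    # Identity column lineage: every output column maps 1:1 to the same input column.
    column_lineage = get_identity_column_lineage_facet(
        field_names=[field.name for field in table.schema],
        input_datasets=[input_dataset],
    )

These are the same facets that get_openlineage_facets_on_complete attaches to the operator's input and output datasets.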