Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experiment] Basic column lineage with sqlglot #2065

Draft
wants to merge 5 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions dlt/common/schema/lineage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# NOTE: this should live in libs/sqlglot

from typing import List, Dict, Optional, TYPE_CHECKING, Tuple, Union, Any

import sqlglot
from sqlglot import Schema, maybe_parse
from sqlglot.optimizer import build_scope, qualify
from sqlglot.lineage import lineage, Scope


def _build_scope(
    sql: str,
    dialect: Optional[str],
    schema: Optional[Union[Dict[str, Any], Schema]],
    **kwargs: Any,
) -> Optional[Scope]:
    """Parse and qualify `sql`, then build the sqlglot scope for the result.

    Args:
        sql: The sql statement to analyze.
        dialect: sqlglot dialect name used for parsing and qualifying. Callers in
            this module pass None when no dialect was requested.
        schema: Table/column mapping (or sqlglot `Schema`) handed to `qualify`.
            Callers in this module pass None when no schema is known.
        **kwargs: Additional options forwarded to `qualify.qualify`. They take
            precedence over the defaults set in this function.

    Returns:
        The value returned by `build_scope` for the qualified expression. Callers
        in this module treat a falsy value as "not a valid select statement".
    """
    expression = maybe_parse(sql, dialect=dialect)

    # Defaults are listed first so that caller-supplied kwargs override them.
    qualify_options = {"validate_qualify_columns": False, "identify": False, **kwargs}
    expression = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **qualify_options,  # type: ignore
    )

    return build_scope(expression)


def get_result_column_names_from_select(
    sql: str,
    schema: Optional[Union[Dict[str, Any], Schema]] = None,
    dialect: Optional[str] = None,
    scope: Optional[Scope] = None,
    **kwargs: Any,
) -> List[str]:
    """Return the list of column names in the result of an sql select statement.

    NOTE: code mostly taken from lineage.py in sqlglot

    Args:
        sql: The sql select statement to inspect.
        schema: Optional table/column mapping (or sqlglot `Schema`) used when the
            scope has to be built here.
        dialect: Optional sqlglot dialect name used when the scope has to be built.
        scope: An already built scope for `sql`. When given, `sql`, `schema`,
            `dialect` and `kwargs` are not used to build a new one.
        **kwargs: Extra options forwarded to `_build_scope` when no scope is given.

    Returns:
        The alias (or name) of every select expression of the scope, in order.

    Raises:
        ValueError: If no usable scope is available for the statement.
    """
    if scope is None:
        scope = _build_scope(sql, dialect, schema, **kwargs)

    if not scope:
        raise ValueError("Expression does not seem to be a valid select statement")

    return [select.alias_or_name for select in scope.expression.selects]


def get_column_origin(
    column_name: str,
    sql: str,
    schema: Optional[Union[Dict[str, Any], Schema]],
    dialect: Optional[str] = None,
) -> Tuple[Optional[str], Optional[str]]:
    """Return the origin of a result column as (table_name, column_name).

    The sqlglot lineage graph of `column_name` is walked completely. Every table
    node sets the origin table and every alias node sets the origin column, so
    when several nodes of a kind are visited the one visited last wins.

    Args:
        column_name: Name of the result column whose origin is looked up.
        sql: The sql select statement producing the column.
        schema: Table/column mapping (or sqlglot `Schema`) passed to `lineage`;
            may be None when no schema is known.
        dialect: Optional sqlglot dialect name passed to `lineage`.

    Returns:
        A tuple (origin_table_name, origin_column_name). Returns (None, None) if
        no origin table was found. The column part may be None on its own when no
        identifier could be located below an alias.
    """
    lineage_graph = lineage(column=column_name, sql=sql, schema=schema, dialect=dialect)

    origin_table_name: Optional[str] = None
    origin_column_name: Optional[str] = None

    for node in lineage_graph.walk():
        expression = node.expression

        if isinstance(expression, sqlglot.expressions.Table):
            origin_table_name = expression.name
        elif isinstance(expression, sqlglot.expressions.Alias):
            # Search the identifier in the expression chain by following `this`.
            # Stop as soon as the chain leaves expression objects (None or a plain
            # value such as a literal's payload) instead of failing on `.this`.
            candidate = expression.this
            while isinstance(candidate, sqlglot.expressions.Expression) and not isinstance(
                candidate, sqlglot.expressions.Identifier
            ):
                candidate = candidate.this
            origin_column_name = (
                candidate.name if isinstance(candidate, sqlglot.expressions.Identifier) else None
            )

    if not origin_table_name:
        return None, None
    return origin_table_name, origin_column_name


def get_result_origins(
    sql: str,
    schema: Optional[Union[Dict[str, Any], Schema]] = None,
    dialect: Optional[str] = None,
) -> List[Tuple[str, Tuple[Optional[str], Optional[str]]]]:
    """Return the origin of every result column of an sql select statement.

    Args:
        sql: The sql select statement to analyze.
        schema: Optional table/column mapping (or sqlglot `Schema`).
        dialect: Optional sqlglot dialect name.

    Returns:
        A list of (column_name, (origin_table_name, origin_column_name)) tuples,
        one per result column and in result order. The list is empty when the
        statement selects a star and no schema is given, because the star cannot
        be expanded into column names in that case.
    """
    scope = _build_scope(sql, dialect, schema)
    selected_columns = get_result_column_names_from_select(sql, schema, dialect, scope)

    # star select without schema is not possible, raise?
    # The membership test also covers a star that is not the first select
    # expression and does not fail on an empty column list.
    if not schema and "*" in selected_columns:
        return []

    # get origins
    return [
        (column_name, get_column_origin(column_name, sql, schema, dialect))
        for column_name in selected_columns
    ]
177 changes: 177 additions & 0 deletions tests/common/schema/test_lineage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""Test schema lineage with sqlglot"""
import pytest
from dlt.common.schema.lineage import get_result_column_names_from_select, get_result_origins
from sqlglot.lineage import lineage


# TODO: if we are working with schemas and catalogs, we need to implement this here too
# Maps table name -> {column name -> column type}. Several expectations in this
# module list columns in the order they are declared here, so keep the order.
EXAMPLE_SCHEMA = {
    "orders": {"order_id": "INTEGER", "customer_id": "INTEGER", "total": "FLOAT"},
    "customers": {"customer_id": "INTEGER", "name": "STRING"},
    "items": {"item_id": "INTEGER", "name": "STRING", "price": "FLOAT"},
    "order_items": {
        "order_id": "INTEGER",
        "item_id": "INTEGER",
        "quantity": "INTEGER",
        "spot_price": "FLOAT",
    },
}


def test_result_column_names_from_select():
    """Check the result column names reported for a series of select statements."""
    # each case: (sql statement, schema or None, expected result column names)
    cases = [
        # simple select, no schema
        ("SELECT total FROM orders", None, ["total"]),
        # star select, no schema, will not know the column names
        ("SELECT * FROM orders", None, ["*"]),
        # simple select, with schema
        ("SELECT total, order_id FROM orders", EXAMPLE_SCHEMA, ["total", "order_id"]),
        # star select, with schema, will know the column names
        ("SELECT * FROM orders", EXAMPLE_SCHEMA, ["order_id", "customer_id", "total"]),
        # select unknown column (works, could be not known column)
        ("SELECT unknown FROM orders", EXAMPLE_SCHEMA, ["unknown"]),
        # star select from joined tables
        (
            "SELECT * FROM orders JOIN customers ON orders.customer_id = customers.customer_id",
            EXAMPLE_SCHEMA,
            ["order_id", "customer_id", "total", "customer_id", "name"],
        ),
        # nested star select
        (
            """
SELECT *
FROM (SELECT order_id, customer_id FROM orders)
""",
            EXAMPLE_SCHEMA,
            ["order_id", "customer_id"],
        ),
        # select with alias
        (
            """
SELECT * FROM (SELECT name AS renamed_name FROM customers)
""",
            EXAMPLE_SCHEMA,
            ["renamed_name"],
        ),
        # aggregate select
        (
            """
SELECT SUM(total) as total_sum FROM orders
""",
            EXAMPLE_SCHEMA,
            ["total_sum"],
        ),
        # triple nested subquery with join
        (
            """
SELECT * FROM (SELECT * FROM (SELECT order_id, spot_price FROM orders o JOIN order_items i ON o.order_id = i.order_id)) LIMIT 5
""",
            EXAMPLE_SCHEMA,
            ["order_id", "spot_price"],
        ),
        # test simple alias (this fails, as nested subquery it will work)
        (
            """
SELECT order_id AS alias_order_id FROM orders
""",
            EXAMPLE_SCHEMA,
            ["alias_order_id"],
        ),
        # group by aggregate
        (
            """
SELECT customer_id, SUM(total) as sum FROM orders GROUP BY customer_id
""",
            EXAMPLE_SCHEMA,
            ["customer_id", "sum"],
        ),
    ]

    # cases run in declaration order; the first mismatch fails the test
    for statement, schema, expected_names in cases:
        assert get_result_column_names_from_select(statement, schema) == expected_names


def test_result_origins():
    """Check the (table, column) origins reported for a series of select statements."""
    orders_star_origins = [
        ("order_id", ("orders", "order_id")),
        ("customer_id", ("orders", "customer_id")),
        ("total", ("orders", "total")),
    ]

    # each case: (sql statement, schema or None, dialect or None, expected origins)
    cases = [
        # simple select no schema
        ("SELECT total FROM orders", None, None, [("total", ("orders", "total"))]),
        # star select no schema, will not know what comes from where
        ("SELECT * FROM orders JOIN customers", None, None, []),
        # join no schema, will still know columns from tables if given in statement
        (
            "SELECT o.order_id, c.name FROM orders o JOIN customers c",
            None,
            None,
            [("order_id", ("orders", "order_id")), ("name", ("customers", "name"))],
        ),
        # simple select with schema
        (
            "SELECT total, customer_id FROM orders",
            EXAMPLE_SCHEMA,
            None,
            [("total", ("orders", "total")), ("customer_id", ("orders", "customer_id"))],
        ),
        # star select
        ("SELECT * FROM orders", EXAMPLE_SCHEMA, None, orders_star_origins),
        # join
        (
            "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.customer_id",
            EXAMPLE_SCHEMA,
            None,
            [("order_id", ("orders", "order_id")), ("name", ("customers", "name"))],
        ),
        # triple nested subquery with join and rename
        (
            """
SELECT * FROM (SELECT * FROM (SELECT o.order_id, i.spot_price as price FROM orders o JOIN order_items i ON o.order_id = i.order_id)) LIMIT 5
""",
            EXAMPLE_SCHEMA,
            None,
            [("order_id", ("orders", "order_id")), ("price", ("order_items", "spot_price"))],
        ),
        # group by with aggregate
        (
            """
SELECT customer_id, SUM(total) as sum FROM orders GROUP BY customer_id
""",
            EXAMPLE_SCHEMA,
            None,
            [("customer_id", ("orders", "customer_id")), ("sum", ("orders", "total"))],
        ),
        # select unknown column
        ("SELECT unknown FROM orders", EXAMPLE_SCHEMA, None, [("unknown", (None, None))]),
        # concatenate two columns, for now it selects the first column type?
        (
            "SELECT (total || ' ' || customer_id) as concat FROM orders",
            EXAMPLE_SCHEMA,
            "duckdb",
            [("concat", ("orders", "total"))],
        ),
        # where clause
        ("SELECT * FROM orders WHERE total > 100", EXAMPLE_SCHEMA, None, orders_star_origins),
    ]

    # cases run in declaration order; the first mismatch fails the test
    for statement, schema, dialect, expected_origins in cases:
        assert get_result_origins(statement, schema, dialect=dialect) == expected_origins
Loading