diff --git a/pkg/pipeline/lineage.go b/pkg/pipeline/lineage.go index d97bfb6a..dff048b3 100644 --- a/pkg/pipeline/lineage.go +++ b/pkg/pipeline/lineage.go @@ -190,7 +190,10 @@ func (p *LineageExtractor) processLineageColumns(foundPipeline *Pipeline, asset if upstream.Table == asset.Name { continue } - upstreamAsset := foundPipeline.GetAssetByName(upstream.Table) + + tableSpec := strings.Split(upstream.Table, ".") + table_name := tableSpec[len(strings.Split(upstream.Table, "."))-1] + upstreamAsset := foundPipeline.GetAssetByName(table_name) if upstreamAsset == nil && upstream.Table != "" { if err := p.addColumnToAsset(asset, lineageCol.Name, nil, &Column{ Name: upstream.Column, @@ -199,7 +202,7 @@ func (p *LineageExtractor) processLineageColumns(foundPipeline *Pipeline, asset Upstreams: []*UpstreamColumn{ { Column: upstream.Column, - Table: strings.ToLower(upstream.Table), + Table: strings.ToLower(table_name), }, }, }); err != nil { diff --git a/pythonsrc/parser/main.py b/pythonsrc/parser/main.py index 35ad4fb1..82a3cc18 100644 --- a/pythonsrc/parser/main.py +++ b/pythonsrc/parser/main.py @@ -44,8 +44,6 @@ def extract_non_selected_columns(parsed: exp.Select) -> list[Column]: result = list(set(cols)) result.sort(key=lambda x: x.name + x.table) - for c in result: - c = Column(name=c.name, table=extract_table_name(c.table)) return result @@ -117,7 +115,7 @@ def get_column_lineage(query: str, schema: dict, dialect: str): nested_schema = schema_dict_to_schema_object(schema) try: optimized = optimize(parsed, nested_schema, dialect=dialect) - except Exception as e: + except Exception: # try again without dialect, this solves some issues, e.g. https://github.com/tobymao/sqlglot/issues/4538 optimized = optimize(parsed, nested_schema) except Exception as e: @@ -157,8 +155,6 @@ def get_column_lineage(query: str, schema: dict, dialect: str): cl = [dict(t) for t in {tuple(d.items()) for d in cl}] cl.sort(key=lambda x: x["table"]) - for c in cl: - c["table"] = extract_table_name(c["table"]) result.append({"name": col["name"], "upstream": cl, "type": col["type"]}) result.sort(key=lambda x: x["name"]) @@ -188,13 +184,6 @@ def find_leaf_nodes(node: Node, leaf_nodes): find_leaf_nodes(child, leaf_nodes) -def extract_table_name(table: str) -> str: - if not table: - return "" - parts = table.split(".") - return parts[-1].lower() - - def merge_parts(table: exp.Table) -> str: return ".".join( part.name for part in table.parts if isinstance(part, exp.Identifier)