Skip to content

Commit

Permalink
feat: added ast function to extract all tables
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal committed Aug 26, 2024
1 parent 9710dc0 commit 0b1d91e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
21 changes: 21 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2818,3 +2818,24 @@ func (lock Lock) GetHighestOrderLock(newLock Lock) Lock {
func Clone[K SQLNode](x K) K {
return CloneSQLNode(x).(K)
}

// ExtractAllTables returns all the table names in the SQLNode
func ExtractAllTables(node SQLNode) []string {
var tables []string
tableMap := make(map[string]any)
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *AliasedTableExpr:
if tblName, ok := node.Expr.(TableName); ok {
name := String(tblName)
if _, exists := tableMap[name]; !exists {
tableMap[name] = nil
tables = append(tables, name)
}
return false, nil
}
}
return true, nil
}, node)
return tables
}
47 changes: 47 additions & 0 deletions go/vt/sqlparser/ast_funcs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,50 @@ func TestColumns_Indexes(t *testing.T) {
})
}
}

// TestExtractTables verifies the functionality of extracting all the tables from the SQLNode.
func TestExtractTables(t *testing.T) {
tcases := []struct {
sql string
expected []string
}{{
sql: "select 1 from a",
expected: []string{"a"},
}, {
sql: "select 1 from a, b",
expected: []string{"a", "b"},
}, {
sql: "select 1 from a join b on a.id = b.id",
expected: []string{"a", "b"},
}, {
sql: "select 1 from a join b on a.id = b.id join c on b.id = c.id",
expected: []string{"a", "b", "c"},
}, {
sql: "select 1 from a join (select id from b) as c on a.id = c.id",
expected: []string{"a", "b"},
}, {
sql: "(select 1 from a) union (select 1 from b)",
expected: []string{"a", "b"},
}, {
sql: "select 1 from a where exists (select 1 from (select id from c) b where a.id = b.id)",
expected: []string{"a", "c"},
}, {
sql: "select 1 from k.a join k.b on a.id = b.id",
expected: []string{"k.a", "k.b"},
}, {
sql: "select 1 from k.a join l.a on k.a.id = l.a.id",
expected: []string{"k.a", "l.a"},
}, {
sql: "select 1 from a join (select id from a) as c on a.id = c.id",
expected: []string{"a"},
}}
parser := NewTestParser()
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
stmt, err := parser.Parse(tcase.sql)
require.NoError(t, err)
tables := ExtractAllTables(stmt)
require.Equal(t, tcase.expected, tables)
})
}
}

0 comments on commit 0b1d91e

Please sign in to comment.