From 0b1d91e5b20a66d431cdac3aa9015c2064d86f61 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Mon, 26 Aug 2024 13:03:47 +0530 Subject: [PATCH] feat: added ast function to extract all tables Signed-off-by: Harshit Gangal --- go/vt/sqlparser/ast_funcs.go | 21 ++++++++++++++ go/vt/sqlparser/ast_funcs_test.go | 47 +++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index ae96fe9c1fe..808e41ab098 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2818,3 +2818,24 @@ func (lock Lock) GetHighestOrderLock(newLock Lock) Lock { func Clone[K SQLNode](x K) K { return CloneSQLNode(x).(K) } + +// ExtractAllTables returns all the table names in the SQLNode +func ExtractAllTables(node SQLNode) []string { + var tables []string + tableMap := make(map[string]any) + _ = Walk(func(node SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *AliasedTableExpr: + if tblName, ok := node.Expr.(TableName); ok { + name := String(tblName) + if _, exists := tableMap[name]; !exists { + tableMap[name] = nil + tables = append(tables, name) + } + return false, nil + } + } + return true, nil + }, node) + return tables +} diff --git a/go/vt/sqlparser/ast_funcs_test.go b/go/vt/sqlparser/ast_funcs_test.go index 7bec47df96f..a3b744729f4 100644 --- a/go/vt/sqlparser/ast_funcs_test.go +++ b/go/vt/sqlparser/ast_funcs_test.go @@ -172,3 +172,50 @@ func TestColumns_Indexes(t *testing.T) { }) } } + +// TestExtractTables verifies the functionality of extracting all the tables from the SQLNode. +func TestExtractTables(t *testing.T) { + tcases := []struct { + sql string + expected []string + }{{ + sql: "select 1 from a", + expected: []string{"a"}, + }, { + sql: "select 1 from a, b", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a join b on a.id = b.id", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a join b on a.id = b.id join c on b.id = c.id", + expected: []string{"a", "b", "c"}, + }, { + sql: "select 1 from a join (select id from b) as c on a.id = c.id", + expected: []string{"a", "b"}, + }, { + sql: "(select 1 from a) union (select 1 from b)", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a where exists (select 1 from (select id from c) b where a.id = b.id)", + expected: []string{"a", "c"}, + }, { + sql: "select 1 from k.a join k.b on a.id = b.id", + expected: []string{"k.a", "k.b"}, + }, { + sql: "select 1 from k.a join l.a on k.a.id = l.a.id", + expected: []string{"k.a", "l.a"}, + }, { + sql: "select 1 from a join (select id from a) as c on a.id = c.id", + expected: []string{"a"}, + }} + parser := NewTestParser() + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + stmt, err := parser.Parse(tcase.sql) + require.NoError(t, err) + tables := ExtractAllTables(stmt) + require.Equal(t, tcase.expected, tables) + }) + } +}