From bf9d0643583442a5496160653d52deadd4bd822f Mon Sep 17 00:00:00 2001
From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com>
Date: Wed, 4 Sep 2024 11:41:30 +0530
Subject: [PATCH] Fix ACL checks for CTEs (#16642)

Signed-off-by: Manan Gupta <manan@planetscale.com>
---
 .../tabletserver/planbuilder/permission.go    | 89 ++++++++++++++-----
 .../planbuilder/permission_test.go            | 39 ++++++++
 2 files changed, 108 insertions(+), 20 deletions(-)

diff --git a/go/vt/vttablet/tabletserver/planbuilder/permission.go b/go/vt/vttablet/tabletserver/planbuilder/permission.go
index dbc6cfccdad..1949d6ce739 100644
--- a/go/vt/vttablet/tabletserver/planbuilder/permission.go
+++ b/go/vt/vttablet/tabletserver/planbuilder/permission.go
@@ -45,17 +45,17 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
 	case *sqlparser.Union:
 		permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
 	case *sqlparser.Insert:
-		permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, permissions)
+		permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, nil, permissions)
 		permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
 	case *sqlparser.Update:
-		permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions)
+		permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions)
 		permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
 	case *sqlparser.Delete:
-		permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions)
+		permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions)
 		permissions = buildSubqueryPermissions(node, tableacl.READER, permissions)
 	case sqlparser.DDLStatement:
 		for _, t := range node.AffectedTables() {
-			permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions)
+			permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions)
 		}
 	case
 		*sqlparser.AlterMigration,
@@ -66,10 +66,10 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
 		permissions = []Permission{} // TODO(shlomi) what are the correct permissions here? Table is unknown
 	case *sqlparser.Flush:
 		for _, t := range node.TableNames {
-			permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions)
+			permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions)
 		}
 	case *sqlparser.Analyze:
-		permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, permissions)
+		permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, nil, permissions)
 	case *sqlparser.OtherAdmin, *sqlparser.CallProc, *sqlparser.Begin, *sqlparser.Commit, *sqlparser.Rollback,
 		*sqlparser.Load, *sqlparser.Savepoint, *sqlparser.Release, *sqlparser.SRollback, *sqlparser.Set, *sqlparser.Show, sqlparser.Explain,
 		*sqlparser.UnlockTables:
@@ -81,43 +81,92 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission {
 }
 
 func buildSubqueryPermissions(stmt sqlparser.Statement, role tableacl.Role, permissions []Permission) []Permission {
-	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
-		if sel, ok := node.(*sqlparser.Select); ok {
-			permissions = buildTableExprsPermissions(sel.From, role, permissions)
+	var cteScopes [][]sqlparser.IdentifierCS
+	sqlparser.Rewrite(stmt, func(cursor *sqlparser.Cursor) bool {
+		switch node := cursor.Node().(type) {
+		case *sqlparser.Select:
+			if node.With != nil {
+				cteScopes = append(cteScopes, gatherCTEs(node.With))
+			}
+			var ctes []sqlparser.IdentifierCS
+			for _, cteScope := range cteScopes {
+				ctes = append(ctes, cteScope...)
+			}
+			permissions = buildTableExprsPermissions(node.From, role, ctes, permissions)
+		case *sqlparser.Delete:
+			if node.With != nil {
+				cteScopes = append(cteScopes, gatherCTEs(node.With))
+			}
+		case *sqlparser.Update:
+			if node.With != nil {
+				cteScopes = append(cteScopes, gatherCTEs(node.With))
+			}
+		case *sqlparser.Union:
+			if node.With != nil {
+				cteScopes = append(cteScopes, gatherCTEs(node.With))
+			}
 		}
-		return true, nil
-	}, stmt)
+		return true
+	}, func(cursor *sqlparser.Cursor) bool {
+		// When we encounter a With expression coming up, we should remove
+		// the last value from the cte scopes to ensure we none of the outer
+		// elements of the query see this table name.
+		_, isWith := cursor.Node().(*sqlparser.With)
+		if isWith {
+			cteScopes = cteScopes[:len(cteScopes)-1]
+		}
+		return true
+	})
 	return permissions
 }
 
-func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission {
+// gatherCTEs gathers the CTEs from the WITH clause.
+func gatherCTEs(with *sqlparser.With) []sqlparser.IdentifierCS {
+	var ctes []sqlparser.IdentifierCS
+	for _, cte := range with.CTEs {
+		ctes = append(ctes, cte.ID)
+	}
+	return ctes
+}
+
+func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
 	for _, node := range node {
-		permissions = buildTableExprPermissions(node, role, permissions)
+		permissions = buildTableExprPermissions(node, role, ctes, permissions)
 	}
 	return permissions
 }
 
-func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission {
+func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
 	switch node := node.(type) {
 	case *sqlparser.AliasedTableExpr:
 		// An AliasedTableExpr can also be a derived table, but we should skip them here
 		// because the buildSubQueryPermissions walker will catch them and extract
 		// the corresponding table names.
 		if tblName, ok := node.Expr.(sqlparser.TableName); ok {
-			permissions = buildTableNamePermissions(tblName, role, permissions)
+			permissions = buildTableNamePermissions(tblName, role, ctes, permissions)
 		}
 	case *sqlparser.ParenTableExpr:
-		permissions = buildTableExprsPermissions(node.Exprs, role, permissions)
+		permissions = buildTableExprsPermissions(node.Exprs, role, ctes, permissions)
 	case *sqlparser.JoinTableExpr:
-		permissions = buildTableExprPermissions(node.LeftExpr, role, permissions)
-		permissions = buildTableExprPermissions(node.RightExpr, role, permissions)
+		permissions = buildTableExprPermissions(node.LeftExpr, role, ctes, permissions)
+		permissions = buildTableExprPermissions(node.RightExpr, role, ctes, permissions)
 	}
 	return permissions
 }
 
-func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, permissions []Permission) []Permission {
+func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission {
+	tableName := node.Name.String()
+	// Check whether this table is a cte or not.
+	// If the table name is qualified, then it cannot be a cte.
+	if node.Qualifier.IsEmpty() {
+		for _, cte := range ctes {
+			if cte.String() == tableName {
+				return permissions
+			}
+		}
+	}
 	permissions = append(permissions, Permission{
-		TableName: node.Name.String(),
+		TableName: tableName,
 		Role:      role,
 	})
 	return permissions
diff --git a/go/vt/vttablet/tabletserver/planbuilder/permission_test.go b/go/vt/vttablet/tabletserver/planbuilder/permission_test.go
index 0ece6ed19b2..ab238661664 100644
--- a/go/vt/vttablet/tabletserver/planbuilder/permission_test.go
+++ b/go/vt/vttablet/tabletserver/planbuilder/permission_test.go
@@ -180,6 +180,45 @@ func TestBuildPermissions(t *testing.T) {
 			TableName: "seq",
 			Role:      tableacl.WRITER,
 		}},
+	}, {
+		input: "with t as (select count(*) as a from user) select a from t",
+		output: []Permission{{
+			TableName: "user",
+			Role:      tableacl.READER,
+		}},
+	}, {
+		input: "with d as (select id, count(*) as a from user) select d.a from music join d on music.user_id = d.id group by 1",
+		output: []Permission{{
+			TableName: "music",
+			Role:      tableacl.READER,
+		}, {
+			TableName: "user",
+			Role:      tableacl.READER,
+		}},
+	}, {
+		input: "WITH t1 AS ( SELECT id FROM t2 ) SELECT * FROM t1 JOIN ks.t1 AS t3",
+		output: []Permission{{
+			TableName: "t1",
+			Role:      tableacl.READER,
+		}, {
+			TableName: "t2",
+			Role:      tableacl.READER,
+		}},
+	}, {
+		input: "WITH RECURSIVE t1 (n) AS ( SELECT id from t2 UNION ALL SELECT n + 1 FROM t1 WHERE n < 5 ) SELECT * FROM t1 JOIN t1 AS t3",
+		output: []Permission{{
+			TableName: "t2",
+			Role:      tableacl.READER,
+		}},
+	}, {
+		input: "(with t1 as (select count(*) as a from user) select a from t1) union  select * from t1",
+		output: []Permission{{
+			TableName: "user",
+			Role:      tableacl.READER,
+		}, {
+			TableName: "t1",
+			Role:      tableacl.READER,
+		}},
 	}}
 
 	for _, tcase := range tcases {