From 939de976abaa318cc641eaca5cd6e3ac99533aeb Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sat, 17 Feb 2024 15:13:48 +0100 Subject: [PATCH 1/3] vtexplain: Fix setting up the column information In https://github.com/vitessio/vitess/pull/13312 we fixed how we encode names in information_schema queries. The problem there though is that the `GetColumnNamesQuery` string was also used inside `vtexplain`. It's usage there was not updated for the new escaping logic. This in turn leads to broken queries during `vtexplain` which then leads to a lot of logs. The issue here is not that we log a lot, we log that we actually do have a problem here which is the case. This fixes the underlying issue of not correctly generating the query as we should. Signed-off-by: Dirkjan Bussink --- go/vt/mysqlctl/schema.go | 14 +++++++------- go/vt/vtexplain/vtexplain_vttablet.go | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index c7ca98c4917..99d26dbb1db 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -66,7 +66,7 @@ func (mysqld *Mysqld) executeSchemaCommands(ctx context.Context, sql string) err return mysqld.executeMysqlScript(ctx, params, sql) } -func encodeEntityName(name string) string { +func EncodeEntityName(name string) string { var buf strings.Builder sqltypes.NewVarChar(name).EncodeSQL(&buf) return buf.String() @@ -80,7 +80,7 @@ func tableListSQL(tables []string) (string, error) { encodedTables := make([]string, len(tables)) for i, tableName := range tables { - encodedTables[i] = encodeEntityName(tableName) + encodedTables[i] = EncodeEntityName(tableName) } return "(" + strings.Join(encodedTables, ", ") + ")", nil @@ -307,13 +307,13 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql if dbName == "" { dbName2 = "database()" } else { - dbName2 = encodeEntityName(dbName) + dbName2 = EncodeEntityName(dbName) } sanitizedTableName, err := sqlescape.UnescapeID(tableName) if err != nil { return "", err } - query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sanitizedTableName)) + query := fmt.Sprintf(GetColumnNamesQuery, dbName2, EncodeEntityName(sanitizedTableName)) qr, err := exec(query, -1, true) if err != nil { return "", err @@ -407,7 +407,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' ORDER BY table_name, SEQ_IN_INDEX` - sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList) + sql = fmt.Sprintf(sql, EncodeEntityName(dbName), tableList) qr, err := conn.Conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err @@ -631,8 +631,8 @@ func GetPrimaryKeyEquivalentColumns(ctx context.Context, exec func(string, int, ) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES' ORDER BY SEQ_IN_INDEX ASC` - encodedDbName := encodeEntityName(dbName) - encodedTable := encodeEntityName(table) + encodedDbName := EncodeEntityName(dbName) + encodedTable := EncodeEntityName(table) sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable) qr, err := exec(sql, 1000, true) if err != nil { diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index b04365a3d0a..72fa92cc817 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -474,8 +474,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio } tEnv.addResult(query, tEnv.getResult(likeQuery)) - likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedLikeTable) - query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable) + likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedLikeTable)) + query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedTable)) if tEnv.getResult(likeQuery) == nil { return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable) } @@ -516,7 +516,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{ Fields: rowTypes, }) - query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable) + query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedTable)) tEnv.addResult(query, &sqltypes.Result{ Fields: colTypes, Rows: colValues, @@ -618,7 +618,7 @@ func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) { // Gen4 supports more complex queries so we now need to // handle multiple FROM clauses - tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From)) + tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From)) for _, from := range selStmt.From { tables = append(tables, getTables(from)...) } From 241d5f5ef47006777dc18ce216d9d501cb1e7cde Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Sun, 18 Feb 2024 15:29:59 +0100 Subject: [PATCH 2/3] Add basic test Signed-off-by: Dirkjan Bussink --- go/vt/vtexplain/vtexplain_vttablet_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/vt/vtexplain/vtexplain_vttablet_test.go b/go/vt/vtexplain/vtexplain_vttablet_test.go index 8de1a146ce6..fc3785a907a 100644 --- a/go/vt/vtexplain/vtexplain_vttablet_test.go +++ b/go/vt/vtexplain/vtexplain_vttablet_test.go @@ -77,6 +77,10 @@ create table t2 ( require.NoError(t, err) defer vte.Stop() + // Check if the correct schema query is registered. + _, found := vte.globalTabletEnv.schemaQueries["SELECT COLUMN_NAME as column_name\n\t\tFROM INFORMATION_SCHEMA.COLUMNS\n\t\tWHERE TABLE_SCHEMA = database() AND TABLE_NAME = 't1'\n\t\tORDER BY ORDINAL_POSITION"] + assert.True(t, found) + sql := "SELECT * FROM t1 INNER JOIN t2 ON t1.id = t2.id" _, err = vte.Run(sql) From 4f04f201f2eb17e15d4561a3a25d4ecfb533d27b Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 19 Feb 2024 10:24:55 +0100 Subject: [PATCH 3/3] Remove unneeded duplicate function There's already `sqltypes.EncodeStringSQL` which implements the exact same thing. Signed-off-by: Dirkjan Bussink --- go/vt/mysqlctl/schema.go | 18 ++++++------------ go/vt/vtexplain/vtexplain_vttablet.go | 6 +++--- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 99d26dbb1db..a5ac38670e7 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -66,12 +66,6 @@ func (mysqld *Mysqld) executeSchemaCommands(ctx context.Context, sql string) err return mysqld.executeMysqlScript(ctx, params, sql) } -func EncodeEntityName(name string) string { - var buf strings.Builder - sqltypes.NewVarChar(name).EncodeSQL(&buf) - return buf.String() -} - // tableListSQL returns an IN clause "('t1', 't2'...) for a list of tables." func tableListSQL(tables []string) (string, error) { if len(tables) == 0 { @@ -80,7 +74,7 @@ func tableListSQL(tables []string) (string, error) { encodedTables := make([]string, len(tables)) for i, tableName := range tables { - encodedTables[i] = EncodeEntityName(tableName) + encodedTables[i] = sqltypes.EncodeStringSQL(tableName) } return "(" + strings.Join(encodedTables, ", ") + ")", nil @@ -307,13 +301,13 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql if dbName == "" { dbName2 = "database()" } else { - dbName2 = EncodeEntityName(dbName) + dbName2 = sqltypes.EncodeStringSQL(dbName) } sanitizedTableName, err := sqlescape.UnescapeID(tableName) if err != nil { return "", err } - query := fmt.Sprintf(GetColumnNamesQuery, dbName2, EncodeEntityName(sanitizedTableName)) + query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqltypes.EncodeStringSQL(sanitizedTableName)) qr, err := exec(query, -1, true) if err != nil { return "", err @@ -407,7 +401,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary' ORDER BY table_name, SEQ_IN_INDEX` - sql = fmt.Sprintf(sql, EncodeEntityName(dbName), tableList) + sql = fmt.Sprintf(sql, sqltypes.EncodeStringSQL(dbName), tableList) qr, err := conn.Conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err @@ -631,8 +625,8 @@ func GetPrimaryKeyEquivalentColumns(ctx context.Context, exec func(string, int, ) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES' ORDER BY SEQ_IN_INDEX ASC` - encodedDbName := EncodeEntityName(dbName) - encodedTable := EncodeEntityName(table) + encodedDbName := sqltypes.EncodeStringSQL(dbName) + encodedTable := sqltypes.EncodeStringSQL(table) sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable) qr, err := exec(sql, 1000, true) if err != nil { diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 72fa92cc817..9ac81000f3d 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -474,8 +474,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio } tEnv.addResult(query, tEnv.getResult(likeQuery)) - likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedLikeTable)) - query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedTable)) + likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedLikeTable)) + query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedTable)) if tEnv.getResult(likeQuery) == nil { return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable) } @@ -516,7 +516,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{ Fields: rowTypes, }) - query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", mysqlctl.EncodeEntityName(sanitizedTable)) + query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sanitizedTable)) tEnv.addResult(query, &sqltypes.Result{ Fields: colTypes, Rows: colValues,