diff --git a/catalog/internal_macro.go b/catalog/internal_macro.go index 3a99800..27852ad 100644 --- a/catalog/internal_macro.go +++ b/catalog/internal_macro.go @@ -7,6 +7,13 @@ type MacroDefinition struct { DDL string } +var ( + SchemaNameSYS string = "__sys__" + MacroNameMyListContains string = "my_list_contains" + + MacroNameMySplitListStr string = "my_split_list_str" +) + type InternalMacro struct { Schema string Name string @@ -55,4 +62,48 @@ var InternalMacros = []InternalMacro{ }, }, }, + { + Schema: "pg_catalog", + Name: "pg_get_expr", + IsTableMacro: false, + Definitions: []MacroDefinition{ + { + Params: []string{"pg_node_tree", "relation_oid"}, + // Do nothing currently + DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`, + }, + { + Params: []string{"pg_node_tree", "relation_oid", "pretty_bool"}, + // Do nothing currently + DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`, + }, + }, + }, + { + Schema: SchemaNameSYS, + Name: MacroNameMyListContains, + IsTableMacro: false, + Definitions: []MacroDefinition{ + { + Params: []string{"l", "v"}, + DDL: `CASE + WHEN typeof(l) = 'VARCHAR' THEN + list_contains(regexp_split_to_array(l::VARCHAR, '[{},\s]+'), v) + ELSE + list_contains(l::text[], v) + END`, + }, + }, + }, + { + Schema: SchemaNameSYS, + Name: MacroNameMySplitListStr, + IsTableMacro: false, + Definitions: []MacroDefinition{ + { + Params: []string{"l"}, + DDL: `regexp_split_to_array(l::VARCHAR, '[{},\s]+')`, + }, + }, + }, } diff --git a/pgserver/in_place_handler.go b/pgserver/in_place_handler.go index 6e70488..329d6bc 100644 --- a/pgserver/in_place_handler.go +++ b/pgserver/in_place_handler.go @@ -237,11 +237,24 @@ var selectionConversions = []SelectionConversion{ needConvert: func(query *ConvertedStatement) bool { sqlStr := RemoveComments(query.String) // TODO(sean): Evaluate the conditions by iterating over the AST. - return getTypeCastRegex().MatchString(sqlStr) + return getSimpleStringMatchingRegex().MatchString(sqlStr) }, doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error { sqlStr := RemoveComments(query.String) - sqlStr = ConvertTypeCast(sqlStr) + sqlStr = SimpleStrReplacement(sqlStr) + query.String = sqlStr + return nil + }, + }, + { + needConvert: func(query *ConvertedStatement) bool { + sqlStr := RemoveComments(query.String) + // TODO: Evaluate the conditions by iterating over the AST. + return getPgAnyOpRegex().MatchString(sqlStr) + }, + doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error { + sqlStr := RemoveComments(query.String) + sqlStr = ConvertAnyOp(sqlStr) query.String = sqlStr return nil }, diff --git a/pgserver/stmt.go b/pgserver/stmt.go index 4471d17..2db3ff1 100644 --- a/pgserver/stmt.go +++ b/pgserver/stmt.go @@ -295,31 +295,57 @@ func ConvertToSys(sql string) string { } var ( - typeCastRegex *regexp.Regexp - initTypeCastRegex sync.Once + pgAnyOpRegex *regexp.Regexp + initPgAnyOpRegex sync.Once ) +// get the regex to match the operator 'ANY' +func getPgAnyOpRegex() *regexp.Regexp { + initPgAnyOpRegex.Do(func() { + pgAnyOpRegex = regexp.MustCompile(`(?i)([^\s(]+)\s*=\s*any\s*\(\s*([^)]*)\s*\)`) + }) + return pgAnyOpRegex +} + +// Replace the operator 'ANY' with a function call. +func ConvertAnyOp(sql string) string { + re := getPgAnyOpRegex() + return re.ReplaceAllString(sql, catalog.SchemaNameSYS+"."+catalog.MacroNameMyListContains+"($2, $1)") +} + +var ( + simpleStrMatchingRegex *regexp.Regexp + initSimpleStrMatchingRegex sync.Once +) + +// TODO(sean): This is a temporary solution. We need to find a better way to handle type cast conversion and column conversion. e.g. Iterating the AST with a visitor pattern. // The Key must be in lowercase. Because the key used for value retrieval is in lowercase. -var typeCastConversion = map[string]string{ +var simpleStringsConversion = map[string]string{ + // type cast conversion "::regclass": "::varchar", + "::regtype": "::varchar", + + // column conversion + "proallargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proallargtypes)", + "proargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proargtypes)", } // This function will return a regex that matches all type casts in the query. -func getTypeCastRegex() *regexp.Regexp { - initTypeCastRegex.Do(func() { - var typeCasts []string - for typeCast := range typeCastConversion { - typeCasts = append(typeCasts, regexp.QuoteMeta(typeCast)) +func getSimpleStringMatchingRegex() *regexp.Regexp { + initSimpleStrMatchingRegex.Do(func() { + var simpleStrings []string + for simpleString := range simpleStringsConversion { + simpleStrings = append(simpleStrings, regexp.QuoteMeta(simpleString)) } - typeCastRegex = regexp.MustCompile(`(?i)(` + strings.Join(typeCasts, "|") + `)`) + simpleStrMatchingRegex = regexp.MustCompile(`(?i)(` + strings.Join(simpleStrings, "|") + `)`) }) - return typeCastRegex + return simpleStrMatchingRegex } -// This function will replace all type casts in the query with the corresponding type cast in the typeCastConversion map. -func ConvertTypeCast(sql string) string { - return getTypeCastRegex().ReplaceAllStringFunc(sql, func(m string) string { - return typeCastConversion[strings.ToLower(m)] +// This function will replace all type casts in the query with the corresponding type cast in the simpleStringsConversion map. +func SimpleStrReplacement(sql string) string { + return getSimpleStringMatchingRegex().ReplaceAllStringFunc(sql, func(m string) string { + return simpleStringsConversion[strings.ToLower(m)] }) }