Skip to content

Commit

Permalink
fix: solve some compatibility issues with pgcli (#355)
Browse files Browse the repository at this point in the history
* fix: replace 'xxx=ANY(yyy)' with 'my_list_contains(yyy, xxx) (#354)

* fix: support pg_catalog.pg_get_expr with 3 params (#354)

* fix: wrap columns 'proallargtypes' and 'proargtypes' to split string into string array (#354)
  • Loading branch information
VWagen1989 authored Jan 17, 2025
1 parent 34b0213 commit 068cd59
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 16 deletions.
51 changes: 51 additions & 0 deletions catalog/internal_macro.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ type MacroDefinition struct {
DDL string
}

var (
SchemaNameSYS string = "__sys__"
MacroNameMyListContains string = "my_list_contains"

MacroNameMySplitListStr string = "my_split_list_str"
)

type InternalMacro struct {
Schema string
Name string
Expand Down Expand Up @@ -55,4 +62,48 @@ var InternalMacros = []InternalMacro{
},
},
},
{
Schema: "pg_catalog",
Name: "pg_get_expr",
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"pg_node_tree", "relation_oid"},
// Do nothing currently
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
},
{
Params: []string{"pg_node_tree", "relation_oid", "pretty_bool"},
// Do nothing currently
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
},
},
},
{
Schema: SchemaNameSYS,
Name: MacroNameMyListContains,
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"l", "v"},
DDL: `CASE
WHEN typeof(l) = 'VARCHAR' THEN
list_contains(regexp_split_to_array(l::VARCHAR, '[{},\s]+'), v)
ELSE
list_contains(l::text[], v)
END`,
},
},
},
{
Schema: SchemaNameSYS,
Name: MacroNameMySplitListStr,
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"l"},
DDL: `regexp_split_to_array(l::VARCHAR, '[{},\s]+')`,
},
},
},
}
17 changes: 15 additions & 2 deletions pgserver/in_place_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,24 @@ var selectionConversions = []SelectionConversion{
needConvert: func(query *ConvertedStatement) bool {
sqlStr := RemoveComments(query.String)
// TODO(sean): Evaluate the conditions by iterating over the AST.
return getTypeCastRegex().MatchString(sqlStr)
return getSimpleStringMatchingRegex().MatchString(sqlStr)
},
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
sqlStr := RemoveComments(query.String)
sqlStr = ConvertTypeCast(sqlStr)
sqlStr = SimpleStrReplacement(sqlStr)
query.String = sqlStr
return nil
},
},
{
needConvert: func(query *ConvertedStatement) bool {
sqlStr := RemoveComments(query.String)
// TODO: Evaluate the conditions by iterating over the AST.
return getPgAnyOpRegex().MatchString(sqlStr)
},
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
sqlStr := RemoveComments(query.String)
sqlStr = ConvertAnyOp(sqlStr)
query.String = sqlStr
return nil
},
Expand Down
54 changes: 40 additions & 14 deletions pgserver/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,31 +295,57 @@ func ConvertToSys(sql string) string {
}

var (
typeCastRegex *regexp.Regexp
initTypeCastRegex sync.Once
pgAnyOpRegex *regexp.Regexp
initPgAnyOpRegex sync.Once
)

// get the regex to match the operator 'ANY'
func getPgAnyOpRegex() *regexp.Regexp {
initPgAnyOpRegex.Do(func() {
pgAnyOpRegex = regexp.MustCompile(`(?i)([^\s(]+)\s*=\s*any\s*\(\s*([^)]*)\s*\)`)
})
return pgAnyOpRegex
}

// Replace the operator 'ANY' with a function call.
func ConvertAnyOp(sql string) string {
re := getPgAnyOpRegex()
return re.ReplaceAllString(sql, catalog.SchemaNameSYS+"."+catalog.MacroNameMyListContains+"($2, $1)")
}

var (
simpleStrMatchingRegex *regexp.Regexp
initSimpleStrMatchingRegex sync.Once
)

// TODO(sean): This is a temporary solution. We need to find a better way to handle type cast conversion and column conversion. e.g. Iterating the AST with a visitor pattern.
// The Key must be in lowercase. Because the key used for value retrieval is in lowercase.
var typeCastConversion = map[string]string{
var simpleStringsConversion = map[string]string{
// type cast conversion
"::regclass": "::varchar",
"::regtype": "::varchar",

// column conversion
"proallargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proallargtypes)",
"proargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proargtypes)",
}

// This function will return a regex that matches all type casts in the query.
func getTypeCastRegex() *regexp.Regexp {
initTypeCastRegex.Do(func() {
var typeCasts []string
for typeCast := range typeCastConversion {
typeCasts = append(typeCasts, regexp.QuoteMeta(typeCast))
func getSimpleStringMatchingRegex() *regexp.Regexp {
initSimpleStrMatchingRegex.Do(func() {
var simpleStrings []string
for simpleString := range simpleStringsConversion {
simpleStrings = append(simpleStrings, regexp.QuoteMeta(simpleString))
}
typeCastRegex = regexp.MustCompile(`(?i)(` + strings.Join(typeCasts, "|") + `)`)
simpleStrMatchingRegex = regexp.MustCompile(`(?i)(` + strings.Join(simpleStrings, "|") + `)`)
})
return typeCastRegex
return simpleStrMatchingRegex
}

// This function will replace all type casts in the query with the corresponding type cast in the typeCastConversion map.
func ConvertTypeCast(sql string) string {
return getTypeCastRegex().ReplaceAllStringFunc(sql, func(m string) string {
return typeCastConversion[strings.ToLower(m)]
// This function will replace all type casts in the query with the corresponding type cast in the simpleStringsConversion map.
func SimpleStrReplacement(sql string) string {
return getSimpleStringMatchingRegex().ReplaceAllStringFunc(sql, func(m string) string {
return simpleStringsConversion[strings.ToLower(m)]
})
}

Expand Down

0 comments on commit 068cd59

Please sign in to comment.