Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: solve some compatibility issues with pgcli #355

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions catalog/internal_macro.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ type MacroDefinition struct {
DDL string
}

var (
SchemaNameSYS string = "__sys__"
MacroNameMyListContains string = "my_list_contains"

MacroNameMySplitListStr string = "my_split_list_str"
)

type InternalMacro struct {
Schema string
Name string
Expand Down Expand Up @@ -55,4 +62,48 @@ var InternalMacros = []InternalMacro{
},
},
},
{
Schema: "pg_catalog",
Name: "pg_get_expr",
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"pg_node_tree", "relation_oid"},
// Do nothing currently
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
},
{
Params: []string{"pg_node_tree", "relation_oid", "pretty_bool"},
// Do nothing currently
DDL: `pg_catalog.pg_get_expr(pg_node_tree, relation_oid)`,
},
},
},
{
Schema: SchemaNameSYS,
Name: MacroNameMyListContains,
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"l", "v"},
DDL: `CASE
WHEN typeof(l) = 'VARCHAR' THEN
list_contains(regexp_split_to_array(l::VARCHAR, '[{},\s]+'), v)
ELSE
list_contains(l::text[], v)
END`,
},
},
},
{
Schema: SchemaNameSYS,
Name: MacroNameMySplitListStr,
IsTableMacro: false,
Definitions: []MacroDefinition{
{
Params: []string{"l"},
DDL: `regexp_split_to_array(l::VARCHAR, '[{},\s]+')`,
},
},
},
}
17 changes: 15 additions & 2 deletions pgserver/in_place_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,24 @@ var selectionConversions = []SelectionConversion{
needConvert: func(query *ConvertedStatement) bool {
sqlStr := RemoveComments(query.String)
// TODO(sean): Evaluate the conditions by iterating over the AST.
return getTypeCastRegex().MatchString(sqlStr)
return getSimpleStringMatchingRegex().MatchString(sqlStr)
},
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
sqlStr := RemoveComments(query.String)
sqlStr = ConvertTypeCast(sqlStr)
sqlStr = SimpleStrReplacement(sqlStr)
query.String = sqlStr
return nil
},
},
{
needConvert: func(query *ConvertedStatement) bool {
sqlStr := RemoveComments(query.String)
// TODO: Evaluate the conditions by iterating over the AST.
return getPgAnyOpRegex().MatchString(sqlStr)
},
doConvert: func(h *ConnectionHandler, query *ConvertedStatement) error {
sqlStr := RemoveComments(query.String)
sqlStr = ConvertAnyOp(sqlStr)
query.String = sqlStr
return nil
},
Expand Down
54 changes: 40 additions & 14 deletions pgserver/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,31 +295,57 @@ func ConvertToSys(sql string) string {
}

var (
typeCastRegex *regexp.Regexp
initTypeCastRegex sync.Once
pgAnyOpRegex *regexp.Regexp
initPgAnyOpRegex sync.Once
)

// get the regex to match the operator 'ANY'
func getPgAnyOpRegex() *regexp.Regexp {
initPgAnyOpRegex.Do(func() {
pgAnyOpRegex = regexp.MustCompile(`(?i)([^\s(]+)\s*=\s*any\s*\(\s*([^)]*)\s*\)`)
})
return pgAnyOpRegex
}

// Replace the operator 'ANY' with a function call.
func ConvertAnyOp(sql string) string {
re := getPgAnyOpRegex()
return re.ReplaceAllString(sql, catalog.SchemaNameSYS+"."+catalog.MacroNameMyListContains+"($2, $1)")
}

var (
simpleStrMatchingRegex *regexp.Regexp
initSimpleStrMatchingRegex sync.Once
)

// TODO(sean): This is a temporary solution. We need to find a better way to handle type cast conversion and column conversion. e.g. Iterating the AST with a visitor pattern.
// The Key must be in lowercase. Because the key used for value retrieval is in lowercase.
var typeCastConversion = map[string]string{
var simpleStringsConversion = map[string]string{
// type cast conversion
"::regclass": "::varchar",
"::regtype": "::varchar",

// column conversion
"proallargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proallargtypes)",
"proargtypes": catalog.SchemaNameSYS + "." + catalog.MacroNameMySplitListStr + "(proargtypes)",
}

// This function will return a regex that matches all type casts in the query.
func getTypeCastRegex() *regexp.Regexp {
initTypeCastRegex.Do(func() {
var typeCasts []string
for typeCast := range typeCastConversion {
typeCasts = append(typeCasts, regexp.QuoteMeta(typeCast))
func getSimpleStringMatchingRegex() *regexp.Regexp {
initSimpleStrMatchingRegex.Do(func() {
var simpleStrings []string
for simpleString := range simpleStringsConversion {
simpleStrings = append(simpleStrings, regexp.QuoteMeta(simpleString))
}
typeCastRegex = regexp.MustCompile(`(?i)(` + strings.Join(typeCasts, "|") + `)`)
simpleStrMatchingRegex = regexp.MustCompile(`(?i)(` + strings.Join(simpleStrings, "|") + `)`)
})
return typeCastRegex
return simpleStrMatchingRegex
}

// This function will replace all type casts in the query with the corresponding type cast in the typeCastConversion map.
func ConvertTypeCast(sql string) string {
return getTypeCastRegex().ReplaceAllStringFunc(sql, func(m string) string {
return typeCastConversion[strings.ToLower(m)]
// This function will replace all type casts in the query with the corresponding type cast in the simpleStringsConversion map.
func SimpleStrReplacement(sql string) string {
return getSimpleStringMatchingRegex().ReplaceAllStringFunc(sql, func(m string) string {
return simpleStringsConversion[strings.ToLower(m)]
})
}

Expand Down
Loading