Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve correctness and error messages for JSON functions. #2517

Merged
merged 12 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions enginetest/queries/json_scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ package queries
import (
querypb "github.com/dolthub/vitess/go/vt/proto/query"

"github.com/dolthub/go-mysql-server/sql/expression/function/json"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)
Expand Down Expand Up @@ -183,7 +181,7 @@ var JsonScripts = []ScriptTest{
Assertions: []ScriptTestAssertion{
{
Query: `select json_value(cast(12.34 as decimal), '$', 'json')`,
ExpectedErr: json.InvalidJsonArgument,
ExpectedErr: sql.ErrInvalidJSONArgument,
},
{
Query: `select json_type(json_value(cast(cast(12.34 as decimal) as json), '$', 'json'))`,
Expand Down Expand Up @@ -257,7 +255,7 @@ var JsonScripts = []ScriptTest{
},
{
Query: `select json_length(json_extract(x, "$.a")) from xy`,
ExpectedErrStr: "failed to extract from expression 'xy.x'; object is not map",
ExpectedErrStr: "invalid data type for JSON data in argument 1 to function json_extract; a JSON string or JSON type is required",
},
{
Query: `select json_length(json_extract(y, "$.a")) from xy`,
Expand Down
5 changes: 4 additions & 1 deletion sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ var (
ErrInvalidChildType = errors.NewKind("%T: invalid child type, got %T, expected %T")

// ErrInvalidJSONText is returned when a JSON string cannot be parsed or unmarshalled
ErrInvalidJSONText = errors.NewKind("Invalid JSON text: %s")
ErrInvalidJSONText = errors.NewKind("Invalid JSON text in argument %d to function %s: \"%s\"")

// ErrInvalidJSONArgument is returned when a JSON function is called with a parameter that is not JSON or a string
ErrInvalidJSONArgument = errors.NewKind("invalid data type for JSON data in argument %d to function %s; a JSON string or JSON type is required")

// ErrDeleteRowNotFound is returned when row being deleted was not found
ErrDeleteRowNotFound = errors.NewKind("row was not found when attempting to delete")
Expand Down
3 changes: 2 additions & 1 deletion sql/expression/function/json/json_array_append.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ func (j JSONArrayAppend) IsNullable() bool {
func (j JSONArrayAppend) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
doc, err := getMutableJSONVal(ctx, row, j.doc)
if err != nil || doc == nil {
return nil, err
return nil, getJsonFunctionError("json_array_append", 1, err)

}

pairs := make([]pathValPair, 0, len(j.pathVals)/2)
Expand Down
18 changes: 13 additions & 5 deletions sql/expression/function/json/json_array_append_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package json

import (
"fmt"
"strings"
"testing"

Expand Down Expand Up @@ -53,10 +54,12 @@ func TestArrayAppend(t *testing.T) {
{f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": [1, 4.1], "b": [2, 3], "c": {"d": "foo"}}`, nil},
{f1, sql.Row{json, "$.a[last]", 4.1}, `{"a": [1, 4.1], "b": [2, 3], "c": {"d": "foo"}}`, nil},
{f1, sql.Row{json, "$[0]", 4.1}, `[{"a": 1, "b": [2, 3], "c": {"d": "foo"}}, 4.1]`, nil},
{f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath},
{f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath},
{f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard},
{f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard},
{f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")},
{f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")},
{f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")},
{f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")},
{f1, sql.Row{1, "$", "test"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_array_append")},
{f1, sql.Row{`}`, "$", "test"}, nil, sql.ErrInvalidJSONText.New(1, "json_array_append", `}`)},
{f1, sql.Row{json, "$", 10.1}, `[{"a": 1, "b": [2, 3], "c": {"d": "foo"}}, 10.1]`, nil},
{f1, sql.Row{nil, "$", 42.7}, nil, nil},
{f1, sql.Row{json, nil, 10}, nil, nil},
Expand Down Expand Up @@ -101,7 +104,12 @@ func TestArrayAppend(t *testing.T) {
req.Equal(expect, result)
} else {
req.Nil(result)
req.Error(tstC.err, err)
if tstC.err == nil {
req.NoError(err)
} else {
req.Error(err)
req.Equal(tstC.err.Error(), err.Error())
}
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/json/json_array_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (j JSONArrayInsert) IsNullable() bool {
func (j JSONArrayInsert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
doc, err := getMutableJSONVal(ctx, row, j.doc)
if err != nil || doc == nil {
return nil, err
return nil, getJsonFunctionError("json_array_insert", 1, err)
}

pairs := make([]pathValPair, 0, len(j.pathVals)/2)
Expand Down
16 changes: 10 additions & 6 deletions sql/expression/function/json/json_array_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,16 @@ func TestArrayInsert(t *testing.T) {
{f1, sql.Row{json, "$.c.d", "test"}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.c.d")},
{f2, sql.Row{json, "$.b[0]", 4.1, "$.c.d", "test"}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.c.d")},
{f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil},
{f1, sql.Row{json, "$.b.c", 4}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.b.c")},
{f1, sql.Row{json, "$.b.c", 4}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 4 of $.b.c")},
{f1, sql.Row{json, "$.a[0]", 4.1}, json, nil},
{f1, sql.Row{json, "$[0]", 4.1}, json, nil},
{f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath},
{f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath},
{f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard},
{f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard},
{f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")},
{f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")},
{f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")},
{f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")},
{f1, sql.Row{1, "$", "test"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_array_insert")},
{f1, sql.Row{`}`, "$", "test"}, nil, sql.ErrInvalidJSONText.New(1, "json_array_insert", `}`)},

{f1, sql.Row{json, "$", 10.1}, nil, fmt.Errorf("Path expression is not a path to a cell in an array: $")},
{f1, sql.Row{nil, "$", 42.7}, nil, nil},
{f1, sql.Row{json, nil, 10}, nil, nil},
Expand Down Expand Up @@ -103,7 +106,8 @@ func TestArrayInsert(t *testing.T) {
req.Equal(expect, result)
} else {
req.Nil(result)
req.Error(tstC.err, err)
req.Error(err)
req.Equal(tstC.err.Error(), err.Error())
}
})
}
Expand Down
74 changes: 45 additions & 29 deletions sql/expression/function/json/json_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package json

import (
goJson "encoding/json"
"fmt"

"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -24,61 +25,75 @@ import (
var ErrInvalidPath = fmt.Errorf("Invalid JSON path expression")
var ErrPathWildcard = fmt.Errorf("Path expressions may not contain the * and ** tokens")

type invalidJson string

var _ error = invalidJson("")

func (err invalidJson) Error() string {
return "invalid json"
}

// getMutableJSONVal returns a JSONValue from the given row and expression. The underling value is deeply copied so that
// you are free to use the mutation functions on the returned value.
// nil will be returned only if the inputs are nil. This will not return an error, so callers must check.
func getMutableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (types.MutableJSON, error) {
doc, err := getJSONDocumentFromRow(ctx, row, json)
if err != nil || doc == nil || doc.Val == nil {
if err != nil || doc == nil {
return nil, err
}

mutable := types.DeepCopyJson(doc.Val)
val, err := doc.ToInterface()
if err != nil {
return nil, err
}
mutable := types.DeepCopyJson(val)
return types.JSONDocument{Val: mutable}, nil
}

// getSearchableJSONVal returns a SearchableJSONValue from the given row and expression. The underling value is not copied
// so it is intended to be used for read-only operations.
// nil will be returned only if the inputs are nil. This will not return an error, so callers must check.
func getSearchableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (sql.JSONWrapper, error) {
doc, err := getJSONDocumentFromRow(ctx, row, json)
if err != nil || doc == nil || doc.Val == nil {
return nil, err
}

return doc, nil
return getJSONDocumentFromRow(ctx, row, json)
}

// getJSONDocumentFromRow returns a JSONDocument from the given row and expression. Helper function only intended to be
// used by functions in this file.
func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) (*types.JSONDocument, error) {
func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) (sql.JSONWrapper, error) {
js, err := json.Eval(ctx, row)
if err != nil || js == nil {
return nil, err
}

var converted interface{}
switch js.(type) {
case string, []interface{}, map[string]interface{}, sql.JSONWrapper:
converted, _, err = types.JSON.Convert(js)
var jsonData interface{}

switch jsType := js.(type) {
case string:
// When coercing a string into a JSON object, don't use LazyJSONDocument; actually unmarshall it.
// This guarantees that we validate and normalize the JSON.
strData, _, err := types.LongBlob.Convert(js)
if err != nil {
return nil, sql.ErrInvalidJSONText.New(js)
return nil, err
}
if err = goJson.Unmarshal(strData.([]byte), &jsonData); err != nil {
return nil, invalidJson(jsType)
}
return types.JSONDocument{Val: jsonData}, nil
case sql.JSONWrapper:
return jsType, nil
default:
return nil, sql.ErrInvalidArgument.New(fmt.Sprintf("%v", js))
}
}

doc, ok := converted.(types.JSONDocument)
if !ok {
// This should never happen, but just in case.
val, err := js.(sql.JSONWrapper).ToInterface()
if err != nil {
return nil, err
}
doc = types.JSONDocument{Val: val}
func getJsonFunctionError(functionName string, argumentPosition int, err error) error {
if sql.ErrInvalidArgument.Is(err) {
return sql.ErrInvalidJSONArgument.New(argumentPosition, functionName)
}

return &doc, nil
if ij, ok := err.(invalidJson); ok {
return sql.ErrInvalidJSONText.New(argumentPosition, functionName, string(ij))
}
return err
}

// pathValPair is a helper struct for use by functions which take json paths paired with a json value. eg. JSON_SET, JSON_INSERT, etc.
Expand All @@ -88,18 +103,19 @@ type pathValPair struct {
}

// buildPath builds a path from the given row and expression
func buildPath(ctx *sql.Context, pathExp sql.Expression, row sql.Row) (interface{}, error) {
func buildPath(ctx *sql.Context, pathExp sql.Expression, row sql.Row) (*string, error) {
path, err := pathExp.Eval(ctx, row)
if err != nil {
return nil, err
}
if path == nil {
return nil, nil
}
if _, ok := path.(string); !ok {
return "", ErrInvalidPath
if s, ok := path.(string); ok {
return &s, nil
} else {
return nil, ErrInvalidPath
}
return path.(string), nil
}

// buildPathValue builds a pathValPair from the given row and expressions. This is a common pattern in json methods to have
Expand All @@ -122,5 +138,5 @@ func buildPathValue(ctx *sql.Context, pathExp sql.Expression, valExp sql.Express
jsonVal = types.JSONDocument{Val: val}
}

return &pathValPair{path.(string), jsonVal}, nil
return &pathValPair{*path, jsonVal}, nil
}
4 changes: 2 additions & 2 deletions sql/expression/function/json/json_contains.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ func (j *JSONContains) IsNullable() bool {
func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
target, err := getSearchableJSONVal(ctx, row, j.JSONTarget)
if err != nil {
return nil, err
return nil, getJsonFunctionError("json_contains", 1, err)
}
if target == nil {
return nil, nil
}

candidate, err := getSearchableJSONVal(ctx, row, j.JSONCandidate)
if err != nil {
return nil, err
return nil, getJsonFunctionError("json_contains", 2, err)
}
if candidate == nil {
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/json/json_contains_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type JSONContainsPath struct {
func (j JSONContainsPath) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
target, err := getSearchableJSONVal(ctx, row, j.doc)
if err != nil || target == nil {
return nil, err
return nil, getJsonFunctionError("json_contains_path", 1, err)
}

oneOrAll, err := j.all.Eval(ctx, row)
Expand Down
11 changes: 7 additions & 4 deletions sql/expression/function/json/json_contains_path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,23 @@ func TestJSONContainsPath(t *testing.T) {
{twoPath, sql.Row{`{"a": 1}`, "all", `$.x`, nil}, false, nil}, // Match MySQL behavior, not docs.
{twoPath, sql.Row{`{"a": 1}`, `all`, `$.a`, nil}, nil, nil},

// JSON NULL documents do NOT result in NULL output.
{onePath, sql.Row{`null`, `all`, `$.a`}, false, nil},

// Error cases
{onePath, sql.Row{`{"a": 1}`, `None`, `$.a`}, nil, errors.New("The oneOrAll argument to json_contains_path may take these values: 'one' or 'all'")},
{onePath, sql.Row{`{"a": 1`, `One`, `$.a`}, nil, errors.New(`Invalid JSON text: {"a": 1`)},
{onePath, sql.Row{`{"a": 1`, `One`, `$.a`}, nil, sql.ErrInvalidJSONText.New(1, "json_contains_path", `{"a": 1`)},
{threePath, sql.Row{`{"a": 1, "b": 2, "c": {"d": {"e" : 42}}}`, `one`, 42, `$.c.d.e`, `$.x`}, nil, errors.New(`Invalid JSON path expression. Path must start with '$', but received: '42'`)},
}

for _, testcase := range testCases {
t.Run(testcase.fCall.String(), func(t *testing.T) {
require := require.New(t)
result, err := testcase.fCall.Eval(sql.NewEmptyContext(), testcase.input)
if testcase.err == nil {
require.NoError(err)
if testcase.err != nil {
require.ErrorContainsf(err, testcase.err.Error(), "Expected error \"%v\" but received \"%v\"", testcase.err, err)
} else {
require.Equal(err.Error(), testcase.err.Error())
require.NoError(err)
}

require.Equal(testcase.expected, result)
Expand Down
26 changes: 20 additions & 6 deletions sql/expression/function/json/json_contains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,15 @@ func TestJSONContains(t *testing.T) {
{f2, sql.Row{`{"a": [1, [2, 3], 4], "b": {"c": "foo", "d": true}}`, `"foo"`}, false, nil},
{f2, sql.Row{"{\"a\": {\"foo\": [1, 2, 3]}}", "{\"a\": {\"foo\": [1]}}"}, true, nil},
{f2, sql.Row{"{\"a\": {\"foo\": [1, 2, 3]}}", "{\"foo\": [1]}"}, false, nil},
{f2, sql.Row{`null`, `null`}, true, nil},
{f2, sql.Row{`null`, `1`}, false, nil},

// Path Tests
{f, sql.Row{json, json, "FOO"}, nil, errors.New("Invalid JSON path expression. Path must start with '$', but received: 'FOO'")},
{f, sql.Row{1, nil, "$.a"}, nil, errors.New("Invalid argument to 1")},
{f, sql.Row{json, 2, "$.e[0][*]"}, nil, errors.New("Invalid argument to 2")},
{f, sql.Row{1, nil, "$.a"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_contains")},
{f, sql.Row{`{"a"`, nil, "$.a"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", `{"a"`)},
{f, sql.Row{json, 2, "$.e[0][*]"}, nil, sql.ErrInvalidJSONArgument.New(2, "json_contains")},
{f, sql.Row{json, `}"a"`, "$.e[0][*]"}, nil, sql.ErrInvalidJSONText.New(2, "json_contains", `}"a"`)},
{f, sql.Row{nil, json, "$.b.c"}, nil, nil},
{f, sql.Row{json, nil, "$.b.c"}, nil, nil},
{f, sql.Row{json, json, "$.foo"}, nil, nil},
Expand All @@ -109,6 +113,15 @@ func TestJSONContains(t *testing.T) {
{f, sql.Row{json, goodMap, "$.e"}, false, nil}, // The path statement selects an array, which does not contain goodMap
{f, sql.Row{json, badMap, "$"}, false, nil}, // false due to key name difference
{f, sql.Row{json, goodMap, "$"}, true, nil},
// The only allowed path for a scalar document is "$"
{f, sql.Row{`null`, `10`, "$"}, false, nil},
{f, sql.Row{`null`, `null`, "$"}, true, nil},
{f, sql.Row{`10`, `10`, "$"}, true, nil},
{f, sql.Row{`10`, `null`, "$"}, false, nil},
{f, sql.Row{`null`, `10`, "$.b"}, nil, nil},
{f, sql.Row{`10`, `null`, "$.b"}, nil, nil},
// JSON_CONTAINS can successfully look up JSON NULL with a path
{f, sql.Row{`{"a": null}`, `null`, "$.a"}, true, nil},

// Miscellaneous Tests
{f2, sql.Row{json, `[1, 2]`}, false, nil}, // When testing containment against a map, scalars and arrays always return false
Expand All @@ -117,9 +130,9 @@ func TestJSONContains(t *testing.T) {
{f2, sql.Row{`["apple", "orange", "banana"]`, `"orange"`}, true, nil},
{f2, sql.Row{`"hello"`, `"hello"`}, true, nil},
{f2, sql.Row{"{}", "{}"}, true, nil},
{f2, sql.Row{"hello", "hello"}, nil, sql.ErrInvalidJSONText.New("hello")},
{f2, sql.Row{"[1,2", "[1]"}, nil, sql.ErrInvalidJSONText.New("[1,2")},
{f2, sql.Row{"[1,2]", "[1"}, nil, sql.ErrInvalidJSONText.New("[1")},
{f2, sql.Row{"hello", "hello"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", "hello")},
{f2, sql.Row{"[1,2", "[1]"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", "[1,2")},
{f2, sql.Row{"[1,2]", "[1"}, nil, sql.ErrInvalidJSONText.New(2, "json_contains", "[1")},
}

for _, tt := range testCases {
Expand All @@ -129,7 +142,8 @@ func TestJSONContains(t *testing.T) {
if tt.err == nil {
require.NoError(err)
} else {
require.Equal(err.Error(), tt.err.Error())
require.Error(err)
require.Equal(tt.err.Error(), err.Error())
}

require.Equal(tt.expected, result)
Expand Down
8 changes: 6 additions & 2 deletions sql/expression/function/json/json_depth.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,17 @@ func (j *JSONDepth) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {

doc, err := getJSONDocumentFromRow(ctx, row, j.JSON)
if err != nil {
return nil, err
return nil, getJsonFunctionError("json_depth", 1, err)
}
if doc == nil {
return nil, nil
}

d, err := depth(doc.Val)
val, err := doc.ToInterface()
if err != nil {
return nil, err
}
d, err := depth(val)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading