diff --git a/enginetest/queries/json_scripts.go b/enginetest/queries/json_scripts.go index 5944cf8743..f742321c96 100644 --- a/enginetest/queries/json_scripts.go +++ b/enginetest/queries/json_scripts.go @@ -17,8 +17,6 @@ package queries import ( querypb "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/go-mysql-server/sql/expression/function/json" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -183,7 +181,7 @@ var JsonScripts = []ScriptTest{ Assertions: []ScriptTestAssertion{ { Query: `select json_value(cast(12.34 as decimal), '$', 'json')`, - ExpectedErr: json.InvalidJsonArgument, + ExpectedErr: sql.ErrInvalidJSONArgument, }, { Query: `select json_type(json_value(cast(cast(12.34 as decimal) as json), '$', 'json'))`, @@ -257,7 +255,7 @@ var JsonScripts = []ScriptTest{ }, { Query: `select json_length(json_extract(x, "$.a")) from xy`, - ExpectedErrStr: "failed to extract from expression 'xy.x'; object is not map", + ExpectedErrStr: "invalid data type for JSON data in argument 1 to function json_extract; a JSON string or JSON type is required", }, { Query: `select json_length(json_extract(y, "$.a")) from xy`, diff --git a/sql/errors.go b/sql/errors.go index 4c8f059b41..b4fccdb061 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -108,7 +108,10 @@ var ( ErrInvalidChildType = errors.NewKind("%T: invalid child type, got %T, expected %T") // ErrInvalidJSONText is returned when a JSON string cannot be parsed or unmarshalled - ErrInvalidJSONText = errors.NewKind("Invalid JSON text: %s") + ErrInvalidJSONText = errors.NewKind("Invalid JSON text in argument %d to function %s: \"%s\"") + + // ErrInvalidJSONArgument is returned when a JSON function is called with a parameter that is not JSON or a string + ErrInvalidJSONArgument = errors.NewKind("invalid data type for JSON data in argument %d to function %s; a JSON string or JSON type is required") // ErrDeleteRowNotFound is returned when row being deleted was not found ErrDeleteRowNotFound = errors.NewKind("row was not found when attempting to delete") diff --git a/sql/expression/function/json/json_array_append.go b/sql/expression/function/json/json_array_append.go index fc49c09705..daba7ef3bb 100644 --- a/sql/expression/function/json/json_array_append.go +++ b/sql/expression/function/json/json_array_append.go @@ -73,7 +73,8 @@ func (j JSONArrayAppend) IsNullable() bool { func (j JSONArrayAppend) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.doc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_array_append", 1, err) + } pairs := make([]pathValPair, 0, len(j.pathVals)/2) diff --git a/sql/expression/function/json/json_array_append_test.go b/sql/expression/function/json/json_array_append_test.go index b151e3c2d8..4e8d47aa0e 100644 --- a/sql/expression/function/json/json_array_append_test.go +++ b/sql/expression/function/json/json_array_append_test.go @@ -15,6 +15,7 @@ package json import ( + "fmt" "strings" "testing" @@ -53,10 +54,12 @@ func TestArrayAppend(t *testing.T) { {f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": [1, 4.1], "b": [2, 3], "c": {"d": "foo"}}`, nil}, {f1, sql.Row{json, "$.a[last]", 4.1}, `{"a": [1, 4.1], "b": [2, 3], "c": {"d": "foo"}}`, nil}, {f1, sql.Row{json, "$[0]", 4.1}, `[{"a": 1, "b": [2, 3], "c": {"d": "foo"}}, 4.1]`, nil}, - {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, + {f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, + {f1, sql.Row{1, "$", "test"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_array_append")}, + {f1, sql.Row{`}`, "$", "test"}, nil, sql.ErrInvalidJSONText.New(1, "json_array_append", `}`)}, {f1, sql.Row{json, "$", 10.1}, `[{"a": 1, "b": [2, 3], "c": {"d": "foo"}}, 10.1]`, nil}, {f1, sql.Row{nil, "$", 42.7}, nil, nil}, {f1, sql.Row{json, nil, 10}, nil, nil}, @@ -101,7 +104,12 @@ func TestArrayAppend(t *testing.T) { req.Equal(expect, result) } else { req.Nil(result) - req.Error(tstC.err, err) + if tstC.err == nil { + req.NoError(err) + } else { + req.Error(err) + req.Equal(tstC.err.Error(), err.Error()) + } } }) } diff --git a/sql/expression/function/json/json_array_insert.go b/sql/expression/function/json/json_array_insert.go index fd92885bb8..74623da9b1 100644 --- a/sql/expression/function/json/json_array_insert.go +++ b/sql/expression/function/json/json_array_insert.go @@ -75,7 +75,7 @@ func (j JSONArrayInsert) IsNullable() bool { func (j JSONArrayInsert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.doc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_array_insert", 1, err) } pairs := make([]pathValPair, 0, len(j.pathVals)/2) diff --git a/sql/expression/function/json/json_array_insert_test.go b/sql/expression/function/json/json_array_insert_test.go index 233e995d3c..44703cc15e 100644 --- a/sql/expression/function/json/json_array_insert_test.go +++ b/sql/expression/function/json/json_array_insert_test.go @@ -52,13 +52,16 @@ func TestArrayInsert(t *testing.T) { {f1, sql.Row{json, "$.c.d", "test"}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.c.d")}, {f2, sql.Row{json, "$.b[0]", 4.1, "$.c.d", "test"}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.c.d")}, {f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil}, - {f1, sql.Row{json, "$.b.c", 4}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 5 of $.b.c")}, + {f1, sql.Row{json, "$.b.c", 4}, nil, fmt.Errorf("A path expression is not a path to a cell in an array at character 4 of $.b.c")}, {f1, sql.Row{json, "$.a[0]", 4.1}, json, nil}, {f1, sql.Row{json, "$[0]", 4.1}, json, nil}, - {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, + {f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, + {f1, sql.Row{1, "$", "test"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_array_insert")}, + {f1, sql.Row{`}`, "$", "test"}, nil, sql.ErrInvalidJSONText.New(1, "json_array_insert", `}`)}, + {f1, sql.Row{json, "$", 10.1}, nil, fmt.Errorf("Path expression is not a path to a cell in an array: $")}, {f1, sql.Row{nil, "$", 42.7}, nil, nil}, {f1, sql.Row{json, nil, 10}, nil, nil}, @@ -103,7 +106,8 @@ func TestArrayInsert(t *testing.T) { req.Equal(expect, result) } else { req.Nil(result) - req.Error(tstC.err, err) + req.Error(err) + req.Equal(tstC.err.Error(), err.Error()) } }) } diff --git a/sql/expression/function/json/json_common.go b/sql/expression/function/json/json_common.go index 222416ab73..261d12ae14 100644 --- a/sql/expression/function/json/json_common.go +++ b/sql/expression/function/json/json_common.go @@ -15,6 +15,7 @@ package json import ( + goJson "encoding/json" "fmt" "github.com/dolthub/go-mysql-server/sql" @@ -24,16 +25,28 @@ import ( var ErrInvalidPath = fmt.Errorf("Invalid JSON path expression") var ErrPathWildcard = fmt.Errorf("Path expressions may not contain the * and ** tokens") +type invalidJson string + +var _ error = invalidJson("") + +func (err invalidJson) Error() string { + return "invalid json" +} + // getMutableJSONVal returns a JSONValue from the given row and expression. The underling value is deeply copied so that // you are free to use the mutation functions on the returned value. // nil will be returned only if the inputs are nil. This will not return an error, so callers must check. func getMutableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (types.MutableJSON, error) { doc, err := getJSONDocumentFromRow(ctx, row, json) - if err != nil || doc == nil || doc.Val == nil { + if err != nil || doc == nil { return nil, err } - mutable := types.DeepCopyJson(doc.Val) + val, err := doc.ToInterface() + if err != nil { + return nil, err + } + mutable := types.DeepCopyJson(val) return types.JSONDocument{Val: mutable}, nil } @@ -41,44 +54,46 @@ func getMutableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (type // so it is intended to be used for read-only operations. // nil will be returned only if the inputs are nil. This will not return an error, so callers must check. func getSearchableJSONVal(ctx *sql.Context, row sql.Row, json sql.Expression) (sql.JSONWrapper, error) { - doc, err := getJSONDocumentFromRow(ctx, row, json) - if err != nil || doc == nil || doc.Val == nil { - return nil, err - } - - return doc, nil + return getJSONDocumentFromRow(ctx, row, json) } // getJSONDocumentFromRow returns a JSONDocument from the given row and expression. Helper function only intended to be // used by functions in this file. -func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) (*types.JSONDocument, error) { +func getJSONDocumentFromRow(ctx *sql.Context, row sql.Row, json sql.Expression) (sql.JSONWrapper, error) { js, err := json.Eval(ctx, row) if err != nil || js == nil { return nil, err } - var converted interface{} - switch js.(type) { - case string, []interface{}, map[string]interface{}, sql.JSONWrapper: - converted, _, err = types.JSON.Convert(js) + var jsonData interface{} + + switch jsType := js.(type) { + case string: + // When coercing a string into a JSON object, don't use LazyJSONDocument; actually unmarshall it. + // This guarantees that we validate and normalize the JSON. + strData, _, err := types.LongBlob.Convert(js) if err != nil { - return nil, sql.ErrInvalidJSONText.New(js) + return nil, err + } + if err = goJson.Unmarshal(strData.([]byte), &jsonData); err != nil { + return nil, invalidJson(jsType) } + return types.JSONDocument{Val: jsonData}, nil + case sql.JSONWrapper: + return jsType, nil default: return nil, sql.ErrInvalidArgument.New(fmt.Sprintf("%v", js)) } +} - doc, ok := converted.(types.JSONDocument) - if !ok { - // This should never happen, but just in case. - val, err := js.(sql.JSONWrapper).ToInterface() - if err != nil { - return nil, err - } - doc = types.JSONDocument{Val: val} +func getJsonFunctionError(functionName string, argumentPosition int, err error) error { + if sql.ErrInvalidArgument.Is(err) { + return sql.ErrInvalidJSONArgument.New(argumentPosition, functionName) } - - return &doc, nil + if ij, ok := err.(invalidJson); ok { + return sql.ErrInvalidJSONText.New(argumentPosition, functionName, string(ij)) + } + return err } // pathValPair is a helper struct for use by functions which take json paths paired with a json value. eg. JSON_SET, JSON_INSERT, etc. @@ -88,7 +103,7 @@ type pathValPair struct { } // buildPath builds a path from the given row and expression -func buildPath(ctx *sql.Context, pathExp sql.Expression, row sql.Row) (interface{}, error) { +func buildPath(ctx *sql.Context, pathExp sql.Expression, row sql.Row) (*string, error) { path, err := pathExp.Eval(ctx, row) if err != nil { return nil, err @@ -96,10 +111,11 @@ func buildPath(ctx *sql.Context, pathExp sql.Expression, row sql.Row) (interface if path == nil { return nil, nil } - if _, ok := path.(string); !ok { - return "", ErrInvalidPath + if s, ok := path.(string); ok { + return &s, nil + } else { + return nil, ErrInvalidPath } - return path.(string), nil } // buildPathValue builds a pathValPair from the given row and expressions. This is a common pattern in json methods to have @@ -122,5 +138,5 @@ func buildPathValue(ctx *sql.Context, pathExp sql.Expression, valExp sql.Express jsonVal = types.JSONDocument{Val: val} } - return &pathValPair{path.(string), jsonVal}, nil + return &pathValPair{*path, jsonVal}, nil } diff --git a/sql/expression/function/json/json_contains.go b/sql/expression/function/json/json_contains.go index 64852c7b80..6a67ed95ec 100644 --- a/sql/expression/function/json/json_contains.go +++ b/sql/expression/function/json/json_contains.go @@ -120,7 +120,7 @@ func (j *JSONContains) IsNullable() bool { func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { target, err := getSearchableJSONVal(ctx, row, j.JSONTarget) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_contains", 1, err) } if target == nil { return nil, nil @@ -128,7 +128,7 @@ func (j *JSONContains) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) candidate, err := getSearchableJSONVal(ctx, row, j.JSONCandidate) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_contains", 2, err) } if candidate == nil { return nil, nil diff --git a/sql/expression/function/json/json_contains_path.go b/sql/expression/function/json/json_contains_path.go index 6ab7480a09..a911584d34 100644 --- a/sql/expression/function/json/json_contains_path.go +++ b/sql/expression/function/json/json_contains_path.go @@ -47,7 +47,7 @@ type JSONContainsPath struct { func (j JSONContainsPath) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { target, err := getSearchableJSONVal(ctx, row, j.doc) if err != nil || target == nil { - return nil, err + return nil, getJsonFunctionError("json_contains_path", 1, err) } oneOrAll, err := j.all.Eval(ctx, row) diff --git a/sql/expression/function/json/json_contains_path_test.go b/sql/expression/function/json/json_contains_path_test.go index 2a141f27d1..965ebcc4e8 100644 --- a/sql/expression/function/json/json_contains_path_test.go +++ b/sql/expression/function/json/json_contains_path_test.go @@ -92,9 +92,12 @@ func TestJSONContainsPath(t *testing.T) { {twoPath, sql.Row{`{"a": 1}`, "all", `$.x`, nil}, false, nil}, // Match MySQL behavior, not docs. {twoPath, sql.Row{`{"a": 1}`, `all`, `$.a`, nil}, nil, nil}, + // JSON NULL documents do NOT result in NULL output. + {onePath, sql.Row{`null`, `all`, `$.a`}, false, nil}, + // Error cases {onePath, sql.Row{`{"a": 1}`, `None`, `$.a`}, nil, errors.New("The oneOrAll argument to json_contains_path may take these values: 'one' or 'all'")}, - {onePath, sql.Row{`{"a": 1`, `One`, `$.a`}, nil, errors.New(`Invalid JSON text: {"a": 1`)}, + {onePath, sql.Row{`{"a": 1`, `One`, `$.a`}, nil, sql.ErrInvalidJSONText.New(1, "json_contains_path", `{"a": 1`)}, {threePath, sql.Row{`{"a": 1, "b": 2, "c": {"d": {"e" : 42}}}`, `one`, 42, `$.c.d.e`, `$.x`}, nil, errors.New(`Invalid JSON path expression. Path must start with '$', but received: '42'`)}, } @@ -102,10 +105,10 @@ func TestJSONContainsPath(t *testing.T) { t.Run(testcase.fCall.String(), func(t *testing.T) { require := require.New(t) result, err := testcase.fCall.Eval(sql.NewEmptyContext(), testcase.input) - if testcase.err == nil { - require.NoError(err) + if testcase.err != nil { + require.ErrorContainsf(err, testcase.err.Error(), "Expected error \"%v\" but received \"%v\"", testcase.err, err) } else { - require.Equal(err.Error(), testcase.err.Error()) + require.NoError(err) } require.Equal(testcase.expected, result) diff --git a/sql/expression/function/json/json_contains_test.go b/sql/expression/function/json/json_contains_test.go index 4bb5d70168..9b0e07e6e1 100644 --- a/sql/expression/function/json/json_contains_test.go +++ b/sql/expression/function/json/json_contains_test.go @@ -91,11 +91,15 @@ func TestJSONContains(t *testing.T) { {f2, sql.Row{`{"a": [1, [2, 3], 4], "b": {"c": "foo", "d": true}}`, `"foo"`}, false, nil}, {f2, sql.Row{"{\"a\": {\"foo\": [1, 2, 3]}}", "{\"a\": {\"foo\": [1]}}"}, true, nil}, {f2, sql.Row{"{\"a\": {\"foo\": [1, 2, 3]}}", "{\"foo\": [1]}"}, false, nil}, + {f2, sql.Row{`null`, `null`}, true, nil}, + {f2, sql.Row{`null`, `1`}, false, nil}, // Path Tests {f, sql.Row{json, json, "FOO"}, nil, errors.New("Invalid JSON path expression. Path must start with '$', but received: 'FOO'")}, - {f, sql.Row{1, nil, "$.a"}, nil, errors.New("Invalid argument to 1")}, - {f, sql.Row{json, 2, "$.e[0][*]"}, nil, errors.New("Invalid argument to 2")}, + {f, sql.Row{1, nil, "$.a"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_contains")}, + {f, sql.Row{`{"a"`, nil, "$.a"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", `{"a"`)}, + {f, sql.Row{json, 2, "$.e[0][*]"}, nil, sql.ErrInvalidJSONArgument.New(2, "json_contains")}, + {f, sql.Row{json, `}"a"`, "$.e[0][*]"}, nil, sql.ErrInvalidJSONText.New(2, "json_contains", `}"a"`)}, {f, sql.Row{nil, json, "$.b.c"}, nil, nil}, {f, sql.Row{json, nil, "$.b.c"}, nil, nil}, {f, sql.Row{json, json, "$.foo"}, nil, nil}, @@ -109,6 +113,15 @@ func TestJSONContains(t *testing.T) { {f, sql.Row{json, goodMap, "$.e"}, false, nil}, // The path statement selects an array, which does not contain goodMap {f, sql.Row{json, badMap, "$"}, false, nil}, // false due to key name difference {f, sql.Row{json, goodMap, "$"}, true, nil}, + // The only allowed path for a scalar document is "$" + {f, sql.Row{`null`, `10`, "$"}, false, nil}, + {f, sql.Row{`null`, `null`, "$"}, true, nil}, + {f, sql.Row{`10`, `10`, "$"}, true, nil}, + {f, sql.Row{`10`, `null`, "$"}, false, nil}, + {f, sql.Row{`null`, `10`, "$.b"}, nil, nil}, + {f, sql.Row{`10`, `null`, "$.b"}, nil, nil}, + // JSON_CONTAINS can successfully look up JSON NULL with a path + {f, sql.Row{`{"a": null}`, `null`, "$.a"}, true, nil}, // Miscellaneous Tests {f2, sql.Row{json, `[1, 2]`}, false, nil}, // When testing containment against a map, scalars and arrays always return false @@ -117,9 +130,9 @@ func TestJSONContains(t *testing.T) { {f2, sql.Row{`["apple", "orange", "banana"]`, `"orange"`}, true, nil}, {f2, sql.Row{`"hello"`, `"hello"`}, true, nil}, {f2, sql.Row{"{}", "{}"}, true, nil}, - {f2, sql.Row{"hello", "hello"}, nil, sql.ErrInvalidJSONText.New("hello")}, - {f2, sql.Row{"[1,2", "[1]"}, nil, sql.ErrInvalidJSONText.New("[1,2")}, - {f2, sql.Row{"[1,2]", "[1"}, nil, sql.ErrInvalidJSONText.New("[1")}, + {f2, sql.Row{"hello", "hello"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", "hello")}, + {f2, sql.Row{"[1,2", "[1]"}, nil, sql.ErrInvalidJSONText.New(1, "json_contains", "[1,2")}, + {f2, sql.Row{"[1,2]", "[1"}, nil, sql.ErrInvalidJSONText.New(2, "json_contains", "[1")}, } for _, tt := range testCases { @@ -129,7 +142,8 @@ func TestJSONContains(t *testing.T) { if tt.err == nil { require.NoError(err) } else { - require.Equal(err.Error(), tt.err.Error()) + require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } require.Equal(tt.expected, result) diff --git a/sql/expression/function/json/json_depth.go b/sql/expression/function/json/json_depth.go index 5c6b0825cf..f29ede00c7 100644 --- a/sql/expression/function/json/json_depth.go +++ b/sql/expression/function/json/json_depth.go @@ -108,13 +108,17 @@ func (j *JSONDepth) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_depth", 1, err) } if doc == nil { return nil, nil } - d, err := depth(doc.Val) + val, err := doc.ToInterface() + if err != nil { + return nil, err + } + d, err := depth(val) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_depth_test.go b/sql/expression/function/json/json_depth_test.go index a56998e5b7..7061b85df9 100644 --- a/sql/expression/function/json/json_depth_test.go +++ b/sql/expression/function/json/json_depth_test.go @@ -34,27 +34,27 @@ func TestJSONDepth(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f1, row: sql.Row{``}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_depth", ``), }, { f: f1, row: sql.Row{`badjson`}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_depth", `badjson`), }, { f: f1, row: sql.Row{true}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_depth"), }, { f: f1, row: sql.Row{1}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_depth"), }, { @@ -157,8 +157,9 @@ func TestJSONDepth(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) return } require.NoError(err) diff --git a/sql/expression/function/json/json_extract.go b/sql/expression/function/json/json_extract.go index 17f8c01b87..14b269f864 100644 --- a/sql/expression/function/json/json_extract.go +++ b/sql/expression/function/json/json_extract.go @@ -76,18 +76,13 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { span, ctx := ctx.Span("function.JSONExtract") defer span.End() - js, err := j.JSON.Eval(ctx, row) + js, err := getSearchableJSONVal(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_extract", 1, err) } - // sql NULLs, should result in sql NULLs. + // If the document is SQL NULL, the result is SQL NULL if js == nil { - return nil, err - } - - js, _, err = types.JSON.Convert(js) - if err != nil { - return nil, err + return nil, nil } searchable, ok := js.(sql.JSONWrapper) diff --git a/sql/expression/function/json/json_extract_test.go b/sql/expression/function/json/json_extract_test.go index 09bf4394aa..2a11b92af6 100644 --- a/sql/expression/function/json/json_extract_test.go +++ b/sql/expression/function/json/json_extract_test.go @@ -47,7 +47,7 @@ func TestJSONExtract(t *testing.T) { ) require.NoError(t, err) - json := map[string]interface{}{ + json := types.JSONDocument{Val: map[string]interface{}{ "a": []interface{}{float64(1), float64(2), float64(3), float64(4)}, "b": map[string]interface{}{ "c": "foo", @@ -58,13 +58,13 @@ func TestJSONExtract(t *testing.T) { []interface{}{float64(3), float64(4)}, }, "f": map[string]interface{}{ - `key.with.dots`: 0, - `key with spaces`: 1, - `key"with"dquotes`: 2, - `key'with'squotes`: 3, - `key\with\backslashes`: 4, + `key.with.dots`: float64(0), + `key with spaces`: float64(1), + `key"with"dquotes`: float64(2), + `key'with'squotes`: float64(3), + `key\with\backslashes`: float64(4), }, - } + }} testCases := []struct { f sql.Expression @@ -73,7 +73,10 @@ func TestJSONExtract(t *testing.T) { err error }{ //{f2, sql.Row{json, "FOO"}, nil, errors.New("should start with '$'")}, + {f2, sql.Row{nil, "$"}, nil, nil}, {f2, sql.Row{nil, "$.b.c"}, nil, nil}, + {f2, sql.Row{"null", "$"}, types.JSONDocument{Val: nil}, nil}, + {f2, sql.Row{"null", "$.b.c"}, nil, nil}, {f2, sql.Row{json, "$.foo"}, nil, nil}, {f2, sql.Row{json, "$.b.c"}, types.JSONDocument{Val: "foo"}, nil}, {f3, sql.Row{json, "$.b.c", "$.b.d"}, types.JSONDocument{Val: []interface{}{"foo", true}}, nil}, @@ -89,6 +92,10 @@ func TestJSONExtract(t *testing.T) { {f2, sql.Row{json, `$.f.key'with'squotes`}, types.JSONDocument{Val: float64(3)}, nil}, {f2, sql.Row{json, `$.f."key'with'squotes"`}, types.JSONDocument{Val: float64(3)}, nil}, + // Error when the document isn't JSON or a coercible string + {f2, sql.Row{1, `$.f`}, nil, sql.ErrInvalidJSONArgument.New(1, "json_extract")}, + {f2, sql.Row{`}`, `$.f`}, nil, sql.ErrInvalidJSONText.New(1, "json_extract", "}")}, + // TODO: Fix these. They work in mysql //{f2, sql.Row{json, `$.f.key\\"with\\"dquotes`}, sql.JSONDocument{Val: 2}, nil}, //{f2, sql.Row{json, `$.f.key\'with\'squotes`}, sql.JSONDocument{Val: 3}, nil}, @@ -105,12 +112,11 @@ func TestJSONExtract(t *testing.T) { t.Run(tt.f.String()+"."+strings.Join(paths, ","), func(t *testing.T) { require := require.New(t) result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err == nil { - require.NoError(err) + if tt.err != nil { + require.ErrorContainsf(err, tt.err.Error(), "Expected error \"%v\" but received \"%v\"", tt.err, err) } else { - require.Error(tt.err, err) + require.NoError(err) } - require.Equal(tt.expected, result) }) } diff --git a/sql/expression/function/json/json_insert.go b/sql/expression/function/json/json_insert.go index 9427466071..9a67a6d352 100644 --- a/sql/expression/function/json/json_insert.go +++ b/sql/expression/function/json/json_insert.go @@ -80,7 +80,7 @@ func (j JSONInsert) IsNullable() bool { func (j JSONInsert) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.doc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_insert", 1, err) } pairs := make([]pathValPair, 0, len(j.pathVals)/2) diff --git a/sql/expression/function/json/json_insert_test.go b/sql/expression/function/json/json_insert_test.go index dfa6614b2f..12f857ee69 100644 --- a/sql/expression/function/json/json_insert_test.go +++ b/sql/expression/function/json/json_insert_test.go @@ -15,6 +15,7 @@ package json import ( + "fmt" "strings" "testing" @@ -40,25 +41,28 @@ func TestInsert(t *testing.T) { expected interface{} err error }{ - {f1, sql.Row{json, "$.a", 10.1}, json, nil}, // insert existing does nothing - {f1, sql.Row{json, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert new - {f1, sql.Row{json, "$.c.d", "test"}, json, nil}, // insert existing nested does nothing - {f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert multiple, one change. - {f1, sql.Row{json, "$.a.e", "test"}, json, nil}, // insert nested does nothing - {f1, sql.Row{json, "$.c.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo","e":"test"}}`, nil}, // insert nested in existing struct - {f1, sql.Row{json, "$.c[5]", 4.1}, `{"a": 1, "b": [2, 3], "c": [{"d": "foo"}, 4.1]}`, nil}, // insert struct with indexing out of range - {f1, sql.Row{json, "$.b[0]", 4.1}, json, nil}, // insert element in array does nothing - {f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil}, // insert element in array out of range - {f1, sql.Row{json, "$.b.c", 4}, json, nil}, // insert nested in array does nothing - {f1, sql.Row{json, "$.a[0]", 4.1}, json, nil}, // struct as array does nothing - {f1, sql.Row{json, "$[0]", 4.1}, json, nil}, // struct does nothing. - {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, // improper struct indexing - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, // invalid path - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, // path contains * wildcard - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, // path contains ** wildcard - {f1, sql.Row{json, "$", 10.1}, json, nil}, // whole document no opt - {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document returns null - {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null + {f1, sql.Row{json, "$.a", 10.1}, json, nil}, // insert existing does nothing + {f1, sql.Row{json, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert new + {f1, sql.Row{json, "$.c.d", "test"}, json, nil}, // insert existing nested does nothing + {f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // insert multiple, one change. + {f1, sql.Row{json, "$.a.e", "test"}, json, nil}, // insert nested does nothing + {f1, sql.Row{json, "$.c.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo","e":"test"}}`, nil}, // insert nested in existing struct + {f1, sql.Row{json, "$.c[5]", 4.1}, `{"a": 1, "b": [2, 3], "c": [{"d": "foo"}, 4.1]}`, nil}, // insert struct with indexing out of range + {f1, sql.Row{json, "$.b[0]", 4.1}, json, nil}, // insert element in array does nothing + {f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil}, // insert element in array out of range + {f1, sql.Row{json, "$.b.c", 4}, json, nil}, // insert nested in array does nothing + {f1, sql.Row{json, "$.a[0]", 4.1}, json, nil}, // struct as array does nothing + {f1, sql.Row{json, "$[0]", 4.1}, json, nil}, // struct does nothing. + {f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, // improper struct indexing + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, // invalid path + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, // path contains * wildcard + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, // path contains ** wildcard + {f1, sql.Row{1, "$.c.**", "test"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_insert")}, // path contains ** wildcard + {f1, sql.Row{`()`, "$.c.**", "test"}, nil, sql.ErrInvalidJSONText.New(1, "json_insert", "()")}, // path contains ** wildcard + {f1, sql.Row{json, "$", 10.1}, json, nil}, // whole document no opt + {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // sql-null document returns sql-null + {f1, sql.Row{"null", "$", 42.7}, "null", nil}, // json-null document returns json-null + {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null // mysql> select JSON_INSERT(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4); // +------------------------------------------------------------------------+ @@ -99,7 +103,8 @@ func TestInsert(t *testing.T) { req.Equal(expect, result) } else { - req.Error(tstC.err, err) + req.Error(err) + req.Equal(err.Error(), tstC.err.Error()) } }) } diff --git a/sql/expression/function/json/json_keys.go b/sql/expression/function/json/json_keys.go index a8dd01ef80..8c62d19cfe 100644 --- a/sql/expression/function/json/json_keys.go +++ b/sql/expression/function/json/json_keys.go @@ -91,7 +91,7 @@ func (j *JSONKeys) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_keys", 1, err) } if doc == nil { return nil, nil @@ -105,7 +105,7 @@ func (j *JSONKeys) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - js, err := jsonpath.JsonPathLookup(doc.Val, path.(string)) + js, err := types.LookupJSONValue(doc, *path) if err != nil { if errors.Is(err, jsonpath.ErrKeyError) { return nil, nil @@ -113,7 +113,16 @@ func (j *JSONKeys) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - switch v := js.(type) { + if js == nil { + return nil, nil + } + + val, err := js.ToInterface() + if err != nil { + return nil, err + } + + switch v := val.(type) { case map[string]any: res := make([]string, 0) for k := range v { diff --git a/sql/expression/function/json/json_keys_test.go b/sql/expression/function/json/json_keys_test.go index eb27706253..cb10e6ab05 100644 --- a/sql/expression/function/json/json_keys_test.go +++ b/sql/expression/function/json/json_keys_test.go @@ -37,7 +37,7 @@ func TestJSONKeys(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f1, @@ -52,7 +52,7 @@ func TestJSONKeys(t *testing.T) { { f: f1, row: sql.Row{1}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_keys"), }, { f: f1, @@ -72,7 +72,7 @@ func TestJSONKeys(t *testing.T) { { f: f1, row: sql.Row{`badjson`}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_keys", "badjson"), }, { f: f1, @@ -98,7 +98,7 @@ func TestJSONKeys(t *testing.T) { { f: f2, row: sql.Row{`{"a": [1, false]}`, 123}, - err: true, + err: fmt.Errorf("Invalid JSON path expression"), }, { f: f2, @@ -133,7 +133,7 @@ func TestJSONKeys(t *testing.T) { { f: f2, row: sql.Row{`{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, "$["}, - err: true, + err: fmt.Errorf("Invalid JSON path expression. Missing ']'"), }, } @@ -145,8 +145,9 @@ func TestJSONKeys(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } else { require.NoError(err) } diff --git a/sql/expression/function/json/json_length.go b/sql/expression/function/json/json_length.go index 702481b48a..3ffe2d7f6c 100644 --- a/sql/expression/function/json/json_length.go +++ b/sql/expression/function/json/json_length.go @@ -17,9 +17,6 @@ package json import ( "fmt" - "github.com/dolthub/jsonpath" - "gopkg.in/src-d/go-errors.v1" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" @@ -78,7 +75,7 @@ func (j *JsonLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_length", 1, err) } if doc == nil { return nil, nil @@ -98,15 +95,21 @@ func (j *JsonLength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, strErr } - res, err := jsonpath.JsonPathLookup(doc.Val, path) + res, err := types.LookupJSONValue(doc, path) + if err != nil { + return nil, err + } + + if res == nil { + return nil, nil + } + + val, err := res.ToInterface() if err != nil { - if errors.Is(err, jsonpath.ErrKeyError) { - return nil, nil - } return nil, err } - switch v := res.(type) { + switch v := val.(type) { case nil: return nil, nil case []interface{}: diff --git a/sql/expression/function/json/json_length_test.go b/sql/expression/function/json/json_length_test.go index e9d7726fe9..c94350f99c 100644 --- a/sql/expression/function/json/json_length_test.go +++ b/sql/expression/function/json/json_length_test.go @@ -36,7 +36,7 @@ func TestJsonLength(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f1, @@ -82,7 +82,7 @@ func TestJsonLength(t *testing.T) { { f: f2, row: sql.Row{`{"a": [1, false]}`, 123}, - err: true, + err: fmt.Errorf("Invalid JSON path expression. Path must start with '$', but received: '123'"), }, { f: f2, @@ -114,6 +114,16 @@ func TestJsonLength(t *testing.T) { row: sql.Row{`{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, "$.d"}, exp: nil, }, + { + f: f2, + row: sql.Row{1, "$.d"}, + err: sql.ErrInvalidJSONArgument.New(1, "json_length"), + }, + { + f: f2, + row: sql.Row{"asdf", "$.d"}, + err: sql.ErrInvalidJSONText.New(1, "json_length", "asdf"), + }, } for _, tt := range testCases { @@ -125,8 +135,9 @@ func TestJsonLength(t *testing.T) { require := require.New(t) // any error case will result in output of 'false' value result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } else { require.NoError(err) } diff --git a/sql/expression/function/json/json_merge_patch.go b/sql/expression/function/json/json_merge_patch.go index 418d3316d8..86537814e1 100644 --- a/sql/expression/function/json/json_merge_patch.go +++ b/sql/expression/function/json/json_merge_patch.go @@ -110,23 +110,32 @@ func (j *JSONMergePatch) IsNullable() bool { func (j *JSONMergePatch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { initDoc, err := getJSONDocumentFromRow(ctx, row, j.JSONs[0]) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_merge_patch", 1, err) } if initDoc == nil { return nil, nil } - result := types.DeepCopyJson(initDoc.Val) - for _, json := range j.JSONs[1:] { - var doc *types.JSONDocument + val, err := initDoc.ToInterface() + if err != nil { + return nil, err + } + + result := types.DeepCopyJson(val) + for i, json := range j.JSONs[1:] { + var doc sql.JSONWrapper doc, err = getJSONDocumentFromRow(ctx, row, json) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_merge_patch", i+2, err) } if doc == nil { return nil, nil } - result = merge(result, doc.Val, true) + val, err = doc.ToInterface() + if err != nil { + return nil, err + } + result = merge(result, val, true) } return types.JSONDocument{Val: result}, nil } diff --git a/sql/expression/function/json/json_merge_patch_test.go b/sql/expression/function/json/json_merge_patch_test.go index 353fb0a48d..f1db8fff89 100644 --- a/sql/expression/function/json/json_merge_patch_test.go +++ b/sql/expression/function/json/json_merge_patch_test.go @@ -32,7 +32,7 @@ func TestJSONMergePatch(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f2, @@ -130,6 +130,26 @@ func TestJSONMergePatch(t *testing.T) { row: sql.Row{`{"a": 1, "b": 2}`, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `"single value"`}, exp: types.MustJSON(`"single value"`), }, + { + f: f3, + row: sql.Row{1, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONArgument.New(1, "json_merge_patch"), + }, + { + f: f3, + row: sql.Row{`{"a": {"one": false, "two": 2.55, "e": 8}}`, 1, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONArgument.New(2, "json_merge_patch"), + }, + { + f: f3, + row: sql.Row{`{`, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONText.New(1, "json_merge_patch", "{"), + }, + { + f: f3, + row: sql.Row{`{"a": {"one": false, "two": 2.55, "e": 8}}`, `}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONText.New(2, "json_merge_patch", "}"), + }, } for _, tt := range testCases { @@ -140,8 +160,9 @@ func TestJSONMergePatch(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) res, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) return } require.NoError(err) diff --git a/sql/expression/function/json/json_merge_preserve.go b/sql/expression/function/json/json_merge_preserve.go index b96bd80db9..64d9861a5e 100644 --- a/sql/expression/function/json/json_merge_preserve.go +++ b/sql/expression/function/json/json_merge_preserve.go @@ -118,23 +118,31 @@ func (j *JSONMergePreserve) IsNullable() bool { func (j *JSONMergePreserve) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { initDoc, err := getJSONDocumentFromRow(ctx, row, j.JSONs[0]) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_merge_preserve", 1, err) } if initDoc == nil { return nil, nil } - result := types.DeepCopyJson(initDoc.Val) - for _, json := range j.JSONs[1:] { - var doc *types.JSONDocument + val, err := initDoc.ToInterface() + if err != nil { + return nil, err + } + result := types.DeepCopyJson(val) + for i, json := range j.JSONs[1:] { + var doc sql.JSONWrapper doc, err = getJSONDocumentFromRow(ctx, row, json) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_merge_preserve", i+2, err) } if doc == nil { return nil, nil } - result = merge(result, doc.Val, false) + val, err = doc.ToInterface() + if err != nil { + return nil, err + } + result = merge(result, val, false) } return types.JSONDocument{Val: result}, nil } diff --git a/sql/expression/function/json/json_merge_preserve_test.go b/sql/expression/function/json/json_merge_preserve_test.go index 1e232af9d1..7a8e5b930c 100644 --- a/sql/expression/function/json/json_merge_preserve_test.go +++ b/sql/expression/function/json/json_merge_preserve_test.go @@ -34,7 +34,7 @@ func TestJSONMergePreserve(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f2, @@ -146,6 +146,26 @@ func TestJSONMergePreserve(t *testing.T) { row: sql.Row{`{"a": 1, "b": 2}`, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `"single value"`}, exp: types.MustJSON(`[{"a": [1, {"e": 8, "one": false, "two": 2.55}], "b": 2}, "single value"]`), }, + { + f: f3, + row: sql.Row{1, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONArgument.New(1, "json_merge_preserve"), + }, + { + f: f3, + row: sql.Row{`{"a": {"one": false, "two": 2.55, "e": 8}}`, 1, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONArgument.New(2, "json_merge_preserve"), + }, + { + f: f3, + row: sql.Row{`{`, `{"a": {"one": false, "two": 2.55, "e": 8}}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONText.New(1, "json_merge_preserve", "{"), + }, + { + f: f3, + row: sql.Row{`{"a": {"one": false, "two": 2.55, "e": 8}}`, `}`, `{"a": 1, "b": 2}`}, + err: sql.ErrInvalidJSONText.New(2, "json_merge_preserve", "}"), + }, } for _, tt := range testCases { @@ -156,8 +176,9 @@ func TestJSONMergePreserve(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) res, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) return } require.NoError(err) diff --git a/sql/expression/function/json/json_overlaps.go b/sql/expression/function/json/json_overlaps.go index 5dd1660e89..8e3b58e5df 100644 --- a/sql/expression/function/json/json_overlaps.go +++ b/sql/expression/function/json/json_overlaps.go @@ -185,21 +185,29 @@ func (j *JSONOverlaps) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) left, err := getJSONDocumentFromRow(ctx, row, j.Left) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_overlaps", 1, err) } if left == nil { return nil, nil } + leftVal, err := left.ToInterface() + if err != nil { + return nil, err + } right, err := getJSONDocumentFromRow(ctx, row, j.Right) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_overlaps", 2, err) } if right == nil { return nil, nil } + rightVal, err := right.ToInterface() + if err != nil { + return nil, err + } - return overlaps(left.Val, right.Val), nil + return overlaps(leftVal, rightVal), nil } // Children implements sql.Expression diff --git a/sql/expression/function/json/json_overlaps_test.go b/sql/expression/function/json/json_overlaps_test.go index 2fb2e97f45..0559b48b34 100644 --- a/sql/expression/function/json/json_overlaps_test.go +++ b/sql/expression/function/json/json_overlaps_test.go @@ -34,23 +34,38 @@ func TestJSONOverlaps(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ // errors { f: f2, row: sql.Row{``}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_overlaps", ``), }, { f: f2, row: sql.Row{``, ``}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_overlaps", ``), }, { f: f2, row: sql.Row{`asdf`, `badjson`}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_overlaps", `asdf`), + }, + { + f: f2, + row: sql.Row{`{}`, `badjson`}, + err: sql.ErrInvalidJSONText.New(2, "json_overlaps", `badjson`), + }, + { + f: f2, + row: sql.Row{1, `{}`}, + err: sql.ErrInvalidJSONArgument.New(1, "json_overlaps"), + }, + { + f: f2, + row: sql.Row{`{}`, 1}, + err: sql.ErrInvalidJSONArgument.New(2, "json_overlaps"), }, // nulls @@ -227,8 +242,9 @@ func TestJSONOverlaps(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } else { require.NoError(err) } diff --git a/sql/expression/function/json/json_pretty.go b/sql/expression/function/json/json_pretty.go index f65a1cd542..f7f1a24070 100644 --- a/sql/expression/function/json/json_pretty.go +++ b/sql/expression/function/json/json_pretty.go @@ -78,12 +78,16 @@ func (j *JSONPretty) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.Child) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_pretty", 1, err) } if doc == nil { return nil, nil } - res, err := json.MarshalIndent(doc.Val, "", " ") + val, err := doc.ToInterface() + if err != nil { + return nil, err + } + res, err := json.MarshalIndent(val, "", " ") if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_pretty_test.go b/sql/expression/function/json/json_pretty_test.go index d2133e5d77..8d5201925e 100644 --- a/sql/expression/function/json/json_pretty_test.go +++ b/sql/expression/function/json/json_pretty_test.go @@ -28,17 +28,20 @@ func TestJSONPretty(t *testing.T) { testCases := []struct { arg sql.Expression exp interface{} - err bool + err error }{ { arg: expression.NewLiteral(``, types.Text), - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_pretty", ""), }, { arg: expression.NewLiteral(`badjson`, types.Text), - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_pretty", "badjson"), + }, + { + arg: expression.NewLiteral(1, types.Int64), + err: sql.ErrInvalidJSONArgument.New(1, "json_pretty"), }, - { arg: expression.NewLiteral(nil, types.Null), exp: nil, @@ -112,8 +115,9 @@ func TestJSONPretty(t *testing.T) { require := require.New(t) f := NewJSONPretty(tt.arg) res, err := f.Eval(sql.NewEmptyContext(), nil) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) return } require.NoError(err) diff --git a/sql/expression/function/json/json_remove.go b/sql/expression/function/json/json_remove.go index d57bda85e2..4bfacb85b3 100644 --- a/sql/expression/function/json/json_remove.go +++ b/sql/expression/function/json/json_remove.go @@ -104,7 +104,7 @@ func (j JSONRemove) Description() string { func (j JSONRemove) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.doc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_remove", 1, err) } for _, path := range j.paths { @@ -116,7 +116,7 @@ func (j JSONRemove) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - doc, _, err = doc.Remove(path.(string)) + doc, _, err = doc.Remove(*path) if err != nil { return nil, err } diff --git a/sql/expression/function/json/json_remove_test.go b/sql/expression/function/json/json_remove_test.go index 782da8ce02..1da65c205d 100644 --- a/sql/expression/function/json/json_remove_test.go +++ b/sql/expression/function/json/json_remove_test.go @@ -42,23 +42,25 @@ func TestRemove(t *testing.T) { expected interface{} err error }{ - {f1, sql.Row{json, "$.a"}, `{"b": [2, 3], "c": {"d": "foo"}}`, nil}, // remove existing - {f1, sql.Row{json, "$.b[0]"}, `{"a": 1, "b": [3], "c": {"d": "foo"}}`, nil}, // remove existing array element - {f1, sql.Row{json, "$.c.d"}, `{"a": 1, "b": [2, 3], "c": {}}`, nil}, // remove existing nested - {f1, sql.Row{json, "$.c"}, `{"a": 1, "b": [2, 3]}`, nil}, // remove existing object - {f1, sql.Row{json, "$.a.e"}, json, nil}, // remove nothing when path not found - {f1, sql.Row{json, "$.c[5]"}, json, nil}, // remove nothing when path not found - {f1, sql.Row{json, "$.b[last]"}, `{"a": 1, "b": [2], "c": {"d": "foo"}}`, nil}, // remove last element in array - {f1, sql.Row{json, "$.b[5]"}, json, nil}, // remove nothing when array index out of bounds - {f1, sql.Row{json, "$[0]"}, json, nil}, // remove nothing when provided a bogus path. - {f1, sql.Row{json, "$.[0]"}, nil, ErrInvalidPath}, // improper struct indexing - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, // invalid path - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, // path contains * wildcard - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, // path contains ** wildcard - {f1, sql.Row{json, "$"}, nil, fmt.Errorf("The path expression '$' is not allowed in this context.")}, // whole document - {f1, sql.Row{nil, "$"}, nil, nil}, // null document - {f2, sql.Row{json, "$.foo", nil}, nil, nil}, // if any path is null, return null - {f2, sql.Row{json, "$.a", "$.b"}, `{"c": {"d": "foo"}}`, nil}, // remove multiple paths + {f1, sql.Row{json, "$.a"}, `{"b": [2, 3], "c": {"d": "foo"}}`, nil}, // remove existing + {f1, sql.Row{json, "$.b[0]"}, `{"a": 1, "b": [3], "c": {"d": "foo"}}`, nil}, // remove existing array element + {f1, sql.Row{json, "$.c.d"}, `{"a": 1, "b": [2, 3], "c": {}}`, nil}, // remove existing nested + {f1, sql.Row{json, "$.c"}, `{"a": 1, "b": [2, 3]}`, nil}, // remove existing object + {f1, sql.Row{json, "$.a.e"}, json, nil}, // remove nothing when path not found + {f1, sql.Row{json, "$.c[5]"}, json, nil}, // remove nothing when path not found + {f1, sql.Row{json, "$.b[last]"}, `{"a": 1, "b": [2], "c": {"d": "foo"}}`, nil}, // remove last element in array + {f1, sql.Row{json, "$.b[5]"}, json, nil}, // remove nothing when array index out of bounds + {f1, sql.Row{json, "$[0]"}, json, nil}, // remove nothing when provided a bogus path. + {f1, sql.Row{json, "$.[0]"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, // improper struct indexing + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, // invalid path + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, // path contains * wildcard + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, // path contains ** wildcard + {f1, sql.Row{json, "$"}, nil, fmt.Errorf("The path expression '$' is not allowed in this context.")}, // whole document + {f1, sql.Row{1, "$"}, nil, sql.ErrInvalidJSONArgument.New(1, "json_remove")}, + {f1, sql.Row{"}{", "$"}, nil, sql.ErrInvalidJSONText.New(1, "json_remove", "}{")}, + {f1, sql.Row{nil, "$"}, nil, nil}, // null document + {f2, sql.Row{json, "$.foo", nil}, nil, nil}, // if any path is null, return null + {f2, sql.Row{json, "$.a", "$.b"}, `{"c": {"d": "foo"}}`, nil}, // remove multiple paths } for _, tstC := range testCases { @@ -85,7 +87,8 @@ func TestRemove(t *testing.T) { req.Equal(expect, result) } else { - req.Error(tstC.err, err) + req.Error(err) + req.Equal(tstC.err.Error(), err.Error()) } }) } diff --git a/sql/expression/function/json/json_replace.go b/sql/expression/function/json/json_replace.go index dd0d663c43..5947c1aad8 100644 --- a/sql/expression/function/json/json_replace.go +++ b/sql/expression/function/json/json_replace.go @@ -75,7 +75,7 @@ func (j JSONReplace) IsNullable() bool { func (j JSONReplace) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.doc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_replace", 1, err) } pairs := make([]pathValPair, 0, len(j.pathVals)/2) diff --git a/sql/expression/function/json/json_replace_test.go b/sql/expression/function/json/json_replace_test.go index 1b840b6358..2bec2a72e3 100644 --- a/sql/expression/function/json/json_replace_test.go +++ b/sql/expression/function/json/json_replace_test.go @@ -15,6 +15,7 @@ package json import ( + "fmt" "strings" "testing" @@ -41,25 +42,28 @@ func TestReplace(t *testing.T) { expected interface{} err error }{ - {f1, sql.Row{json, "$.a", 10.1}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace existing - {f1, sql.Row{json, "$.e", "new"}, json, nil}, // replace non-existing does nothing - {f1, sql.Row{json, "$.c.d", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "test"}}`, nil}, // replace nested - {f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace multiple, one change. - {f1, sql.Row{json, "$.a.e", "test"}, json, nil}, // replace nested non-existent does nothing - {f1, sql.Row{json, "$.c.e", "test"}, json, nil}, // replace nested in existing struct missing field does nothing - {f1, sql.Row{json, "$.c[5]", 4.1}, json, nil}, // replace struct with indexing out of range - {f1, sql.Row{json, "$.b[0]", 4.1}, `{"a": 1, "b": [4.1, 3], "c": {"d": "foo"}}`, nil}, // replace element in array - {f1, sql.Row{json, "$.b[5]", 4.1}, json, nil}, // replace element in array out of range does nothing - {f1, sql.Row{json, "$.b.c", 4}, json, nil}, // replace nested in array does nothing - {f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": 4.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace scalar when treated as array - {f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil}, // replace root element when treated as array - {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, // improper struct indexing - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, // invalid path - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, // path contains * wildcard - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, // path contains ** wildcard - {f1, sql.Row{json, "$", 10.1}, `10.1`, nil}, // replace root element - {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document returns null - {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null + {f1, sql.Row{json, "$.a", 10.1}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace existing + {f1, sql.Row{json, "$.e", "new"}, json, nil}, // replace non-existing does nothing + {f1, sql.Row{json, "$.c.d", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "test"}}`, nil}, // replace nested + {f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace multiple, one change. + {f1, sql.Row{json, "$.a.e", "test"}, json, nil}, // replace nested non-existent does nothing + {f1, sql.Row{json, "$.c.e", "test"}, json, nil}, // replace nested in existing struct missing field does nothing + {f1, sql.Row{json, "$.c[5]", 4.1}, json, nil}, // replace struct with indexing out of range + {f1, sql.Row{json, "$.b[0]", 4.1}, `{"a": 1, "b": [4.1, 3], "c": {"d": "foo"}}`, nil}, // replace element in array + {f1, sql.Row{json, "$.b[5]", 4.1}, json, nil}, // replace element in array out of range does nothing + {f1, sql.Row{json, "$.b.c", 4}, json, nil}, // replace nested in array does nothing + {f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": 4.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // replace scalar when treated as array + {f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil}, // replace root element when treated as array + {f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, // improper struct indexing + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, // invalid path + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, // path contains * wildcard + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, // path contains ** wildcard + {f1, sql.Row{1, "$[0]", 4.1}, `4.1`, sql.ErrInvalidJSONArgument.New(1, "json_replace")}, + {f1, sql.Row{``, "$[0]", 4.1}, `4.1`, sql.ErrInvalidJSONText.New(1, "json_replace", "")}, + {f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil}, + {f1, sql.Row{json, "$", 10.1}, `10.1`, nil}, // replace root element + {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document returns null + {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null } for _, tstC := range testCases { @@ -86,7 +90,8 @@ func TestReplace(t *testing.T) { req.Equal(expect, result) } else { - req.Error(tstC.err, err) + req.Error(err) + req.Equal(tstC.err.Error(), err.Error()) } }) } diff --git a/sql/expression/function/json/json_search.go b/sql/expression/function/json/json_search.go index f9c347c5bf..ac53967a62 100644 --- a/sql/expression/function/json/json_search.go +++ b/sql/expression/function/json/json_search.go @@ -206,7 +206,7 @@ func (j *JSONSearch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_search", 1, err) } if doc == nil { return nil, nil @@ -289,16 +289,21 @@ func (j *JSONSearch) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } else if newPath == nil { return nil, nil } else { - path = newPath.(string) + path = *newPath } paths = append(paths, path) } } + val, err := doc.ToInterface() + if err != nil { + return nil, err + } + seen := make(map[string]struct{}) var results []string for _, path := range paths { - js, err := jsonpath.JsonPathLookup(doc.Val, path) + js, err := jsonpath.JsonPathLookup(val, path) if err != nil && !errors.Is(err, jsonpath.ErrKeyError) { return nil, err } diff --git a/sql/expression/function/json/json_search_test.go b/sql/expression/function/json/json_search_test.go index 3107dfe97f..af0bb83e9f 100644 --- a/sql/expression/function/json/json_search_test.go +++ b/sql/expression/function/json/json_search_test.go @@ -53,28 +53,33 @@ func TestJSONSearch(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error skip bool }{ + { + f: f3, + row: sql.Row{1, "one", "abc"}, + err: sql.ErrInvalidJSONArgument.New(1, "json_search"), + }, { f: f3, row: sql.Row{"", "one", "abc"}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_search", ""), }, { f: f3, row: sql.Row{json, "NotOneOrAll", "abc"}, - err: true, + err: errOneOrAll, }, { f: f3, row: sql.Row{json, "one ", "abc"}, - err: true, + err: errOneOrAll, }, { f: f4, row: sql.Row{json, "one", "abc", "badescape"}, - err: true, + err: errBadEscape, }, { @@ -246,8 +251,9 @@ func TestJSONSearch(t *testing.T) { } require := require.New(t) res, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) return } require.NoError(err) diff --git a/sql/expression/function/json/json_set.go b/sql/expression/function/json/json_set.go index 8a177e47bd..d9cab5d172 100644 --- a/sql/expression/function/json/json_set.go +++ b/sql/expression/function/json/json_set.go @@ -123,7 +123,7 @@ func (j *JSONSet) IsNullable() bool { func (j *JSONSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getMutableJSONVal(ctx, row, j.JSONDoc) if err != nil || doc == nil { - return nil, err + return nil, getJsonFunctionError("json_set", 1, err) } pairs := make([]pathValPair, 0, len(j.PathAndVals)/2) diff --git a/sql/expression/function/json/json_set_test.go b/sql/expression/function/json/json_set_test.go index b4e283059a..30a393ef46 100644 --- a/sql/expression/function/json/json_set_test.go +++ b/sql/expression/function/json/json_set_test.go @@ -16,6 +16,7 @@ package json import ( json2 "encoding/json" + "fmt" "strconv" "strings" "testing" @@ -67,10 +68,12 @@ func TestJSONSet(t *testing.T) { {f1, sql.Row{json, "$.b.c", 4}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // set nested in array does nothing {f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": 4.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // update single element with indexing {f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil}, // struct indexing - {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, // improper struct indexing - {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, // invalid path - {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, // path contains * wildcard - {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, // path contains ** wildcard + {f1, sql.Row{json, "$.[0]", 4.1}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 2 of $.[0]")}, // improper struct indexing + {f1, sql.Row{json, "foo", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Path must start with '$'")}, // invalid path + {f1, sql.Row{json, "$.c.*", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.*")}, // path contains * wildcard + {f1, sql.Row{json, "$.c.**", "test"}, nil, fmt.Errorf("Invalid JSON path expression. Expected field name after '.' at character 4 of $.c.**")}, // path contains ** wildcard + {f1, sql.Row{1, "$", 10.1}, `10.1`, sql.ErrInvalidJSONArgument.New(1, "json_set")}, // whole document + {f1, sql.Row{"#", "$", 10.1}, `10.1`, sql.ErrInvalidJSONText.New(1, "json_set", "#")}, // whole document {f1, sql.Row{json, "$", 10.1}, `10.1`, nil}, // whole document {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null @@ -124,7 +127,8 @@ func TestJSONSet(t *testing.T) { require.Equal(expect, result) } else { - require.Error(tt.err, err) + require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } }) } diff --git a/sql/expression/function/json/json_type.go b/sql/expression/function/json/json_type.go index f13613eaa7..21e34d6470 100644 --- a/sql/expression/function/json/json_type.go +++ b/sql/expression/function/json/json_type.go @@ -82,13 +82,18 @@ func (j JSONType) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { doc, err := getJSONDocumentFromRow(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_type", 1, err) } if doc == nil { return "NULL", nil } - switch v := doc.Val.(type) { + val, err := doc.ToInterface() + if err != nil { + return nil, err + } + + switch v := val.(type) { case nil: return "NULL", nil case bool: diff --git a/sql/expression/function/json/json_type_test.go b/sql/expression/function/json/json_type_test.go index 8f1bf31fb5..1c5de44019 100644 --- a/sql/expression/function/json/json_type_test.go +++ b/sql/expression/function/json/json_type_test.go @@ -36,37 +36,37 @@ func TestJSONType(t *testing.T) { f sql.Expression row sql.Row exp interface{} - err bool + err error }{ { f: f1, row: sql.Row{``}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_type", ""), }, { f: f1, row: sql.Row{`badjson`}, - err: true, + err: sql.ErrInvalidJSONText.New(1, "json_type", "badjson"), }, { f: f1, row: sql.Row{true}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_type"), }, { f: f1, row: sql.Row{1}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_type"), }, { f: f1, row: sql.Row{1.5}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_type"), }, { f: f1, row: sql.Row{decimal.New(15, -1)}, - err: true, + err: sql.ErrInvalidJSONArgument.New(1, "json_type"), }, { @@ -168,8 +168,9 @@ func TestJSONType(t *testing.T) { t.Run(strings.Join(args, ", "), func(t *testing.T) { require := require.New(t) result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) - if tt.err { + if tt.err != nil { require.Error(err) + require.Equal(tt.err.Error(), err.Error()) } else { require.NoError(err) } diff --git a/sql/expression/function/json/json_value.go b/sql/expression/function/json/json_value.go index ddb7b9e102..dc4a225fcf 100644 --- a/sql/expression/function/json/json_value.go +++ b/sql/expression/function/json/json_value.go @@ -19,9 +19,7 @@ import ( "fmt" "strings" - "github.com/dolthub/jsonpath" "github.com/dolthub/vitess/go/sqltypes" - "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" @@ -87,18 +85,24 @@ func (j *JsonValue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { span, ctx := ctx.Span("function.JsonValue") defer span.End() - js, err := j.JSON.Eval(ctx, row) + js, err := getSearchableJSONVal(ctx, row, j.JSON) if err != nil { - return nil, err + return nil, getJsonFunctionError("json_value", 1, err) } - // sql NULLs, should result in sql NULLs. + // If the document is SQL NULL, the result is SQL NULL if js == nil { return nil, nil } - jsonData, err := GetJSONFromWrapperOrCoercibleString(js) - if err != nil { - return nil, err + // json NULLs also result in sql NULLs. + cmp, err := types.CompareJSON(js, types.JSONDocument{Val: nil}) + if cmp == 0 { + return nil, nil + } + + searchable, ok := js.(sql.JSONWrapper) + if !ok { + return fmt.Errorf("expected types.JSONValue, found: %T", js), nil } path, err := j.Path.Eval(ctx, row) @@ -106,21 +110,22 @@ func (j *JsonValue) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - res, err := jsonpath.JsonPathLookup(jsonData, path.(string)) - if err != nil { + var res interface{} + res, err = types.LookupJSONValue(searchable, path.(string)) + if err != nil || res == nil { return nil, err } - switch r := res.(type) { - case nil: + // This is NOT CORRECT, but it prevents existing tests from regressing when the jsonpath module returns [] for + // bad lookups on arrays, instead of an error. Note that this will cause lookups that expect [] to return incorrect + // results. + // See https://github.com/dolthub/dolt/issues/7905 for more information. + cmp, err = types.CompareJSON(res, types.JSONDocument{Val: []interface{}{}}) + if err != nil { + return nil, err + } + if cmp == 0 { return nil, nil - case []interface{}: - if len(r) == 0 { - return nil, nil - } - res = types.JSONDocument{Val: res} - case map[string]interface{}: - res = types.JSONDocument{Val: res} } if j.Typ != nil { @@ -163,13 +168,11 @@ func (j *JsonValue) String() string { return fmt.Sprintf("json_value(%s)", strings.Join(parts, ", ")) } -var InvalidJsonArgument = errors.NewKind("invalid data type for JSON data in argument 1 to function json_value; a JSON string or JSON type is required") - // GetJSONFromWrapperOrCoercibleString takes a valid argument for JSON functions (either a JSON wrapper type or a string) // and unwraps the JSON, or coerces the string into JSON. The return value can return any type that can be stored in // a JSON column, not just maps. For a complete list, see // https://dev.mysql.com/doc/refman/8.3/en/json-attribute-functions.html#function_json-type -func GetJSONFromWrapperOrCoercibleString(js interface{}) (jsonData interface{}, err error) { +func GetJSONFromWrapperOrCoercibleString(js interface{}, functionName string, argumentPosition int) (jsonData interface{}, err error) { // The first parameter can be either JSON or a string. switch jsType := js.(type) { case string: @@ -184,6 +187,6 @@ func GetJSONFromWrapperOrCoercibleString(js interface{}) (jsonData interface{}, case sql.JSONWrapper: return jsType.ToInterface() default: - return nil, InvalidJsonArgument.New() + return nil, sql.ErrInvalidJSONArgument.New(argumentPosition, functionName) } } diff --git a/sql/expression/function/json/json_value_test.go b/sql/expression/function/json/json_value_test.go index 8d02c012e8..51bdf00725 100644 --- a/sql/expression/function/json/json_value_test.go +++ b/sql/expression/function/json/json_value_test.go @@ -15,6 +15,7 @@ package json import ( + "fmt" "strings" "testing" @@ -35,6 +36,7 @@ func TestJsonValue(t *testing.T) { typ sql.Type path string exp interface{} + err error }{ {row: sql.Row{`null`}, exp: nil}, {row: sql.Row{`1`}, exp: "1"}, @@ -46,12 +48,14 @@ func TestJsonValue(t *testing.T) { {row: sql.Row{`[1, false]`}, path: "$[0]", exp: "1"}, {row: sql.Row{`[1, {"a": 1}]`}, path: "$[1].a", typ: types.Int64, exp: int64(1)}, {row: sql.Row{`[1, {"a": 1}]`}, path: "$[1]", typ: types.JSON, exp: types.MustJSON(`{"a": 1}`)}, + {row: sql.Row{1}, path: `$.f`, err: sql.ErrInvalidJSONArgument.New(1, "json_value")}, + {row: sql.Row{`}`}, path: `$.f`, err: sql.ErrInvalidJSONText.New(1, "json_value", "}")}, } for _, tt := range tests { args := make([]string, len(tt.row)) for i, a := range tt.row { - args[i] = a.(string) + args[i] = fmt.Sprint(a) } if tt.path == "" { tt.path = "$" @@ -71,8 +75,14 @@ func TestJsonValue(t *testing.T) { f, _ := NewJsonValue(args...) require := require.New(t) // any error case will result in output of 'false' value - result, _ := f.Eval(sql.NewEmptyContext(), tt.row) - require.Equal(tt.exp, result) + result, err := f.Eval(sql.NewEmptyContext(), tt.row) + if tt.err == nil { + require.NoError(err) + require.Equal(tt.exp, result) + } else { + require.Error(err) + require.Equal(tt.err.Error(), err.Error()) + } }) } } diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index b543d9e5fa..6b0e6f522e 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -189,7 +189,7 @@ func (b *BaseBuilder) buildJSONTable(ctx *sql.Context, n *plan.JSONTable, row sq return &jsonTableRowIter{}, nil } - jsonData, err := json.GetJSONFromWrapperOrCoercibleString(data) + jsonData, err := json.GetJSONFromWrapperOrCoercibleString(data, "json_table", 1) if err != nil { return nil, err } diff --git a/sql/types/json_value.go b/sql/types/json_value.go index 7de56f8776..3a464f95a5 100644 --- a/sql/types/json_value.go +++ b/sql/types/json_value.go @@ -135,20 +135,6 @@ func (doc JSONDocument) String() string { return result } -// Contains returns nil in case of a nil value for either the doc.Val or candidate. Otherwise -// it returns a bool -func (doc JSONDocument) Contains(candidate sql.JSONWrapper) (val interface{}, err error) { - candidateVal, err := candidate.ToInterface() - if err != nil { - return nil, err - } - return ContainsJSON(doc.Val, candidateVal) -} - -func (doc JSONDocument) Extract(path string) (sql.JSONWrapper, error) { - return LookupJSONValue(doc, path) -} - // LazyJSONDocument is an implementation of sql.JSONWrapper that wraps a JSON string and defers deserializing // it unless needed. This is more efficient for queries that interact with JSON values but don't care about their structure. type LazyJSONDocument struct { @@ -209,12 +195,13 @@ func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { if err.Error() == "should start with '$'" { err = fmt.Errorf("Invalid JSON path expression. Path must start with '$', but received: '%s'", path) } + // jsonpath poorly handles unmatched [] in paths. + if strings.Contains(err.Error(), "len(tail) should") { + return nil, fmt.Errorf("Invalid JSON path expression. Missing ']'") + } return nil, err } - // Lookup(obj) throws an error if obj is nil. We want lookups on a json null - // to always result in sql NULL, except in the case of the identity lookup - // $. r, err := j.ToInterface() if err != nil { return nil, err @@ -223,6 +210,13 @@ func LookupJSONValue(j sql.JSONWrapper, path string) (sql.JSONWrapper, error) { return nil, nil } + // For non-object, non-array candidates, if the path is not "$", return SQL NULL + _, isObject := r.(JsonObject) + _, isArray := r.(JsonArray) + if !isObject && !isArray { + return nil, nil + } + val, err := c.Lookup(r) if err != nil { if strings.Contains(err.Error(), "key error") { @@ -263,9 +257,9 @@ func ConcatenateJSONValues(ctx *sql.Context, vals ...sql.JSONWrapper) (sql.JSONW return JSONDocument{Val: arr}, nil } -func ContainsJSON(a, b interface{}) (interface{}, error) { - if a == nil || b == nil { - return nil, nil +func ContainsJSON(a, b interface{}) (bool, error) { + if a == nil { + return b == nil, nil } switch a := a.(type) {