From f10917a7e243bb1c8922f5c2bc6430fa5235ea3d Mon Sep 17 00:00:00 2001 From: liujian Date: Sat, 15 Jun 2024 15:29:40 +0800 Subject: [PATCH] Update xsql support for PB non-scalar fields --- src/xsql/db_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++--- src/xsql/fetcher.go | 27 +++++++++++++++++------ 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/src/xsql/db_test.go b/src/xsql/db_test.go index 5c6d2ed..f152ef4 100644 --- a/src/xsql/db_test.go +++ b/src/xsql/db_test.go @@ -24,6 +24,25 @@ type Test struct { Enum Enum `xsql:"enum" json:"-"` } +type TestJsonStruct struct { + Test + Json JsonItem `xsql:"json"` +} + +type TestJsonStructPtr struct { + Test + Json *JsonItem `xsql:"json"` +} + +type TestJsonSlice struct { + Test + Json []int `xsql:"json"` +} + +type JsonItem struct { + Foo string `xsql:"foo"` +} + func (t Test) TableName() string { return "xsql" } @@ -79,10 +98,12 @@ CREATE TABLE #xsql# ( #bar# datetime DEFAULT NULL, #bool# int NOT NULL DEFAULT '0', #enum# int NOT NULL DEFAULT '0', + #json# json DEFAULT NULL, PRIMARY KEY (#id#) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; -INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (1, 'v', '2022-04-14 23:49:48', 1, 1); -INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#) VALUES (2, 'v1', '2022-04-14 23:50:00', 1, 1); +INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (1, 'v', '2022-04-12 23:50:00', 1, 1, '{"foo":"bar"}'); +INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (2, 'v1', '2022-04-13 23:50:00', 1, 1, '[1,2]'); +INSERT INTO #xsql# (#id#, #foo#, #bar#, #bool#, #enum#, #json#) VALUES (3, 'v2', '2022-04-14 23:50:00', 1, 1, null); ` DB := newDB() _, err := DB.Exec(strings.ReplaceAll(q, "#", "`")) @@ -483,7 +504,6 @@ func TestTxRollback(t *testing.T) { func TestPbTimestamp(t *testing.T) { a := assert.New(t) - DB := newDB() // Insert @@ -508,3 +528,29 @@ func TestPbTimestamp(t *testing.T) { a.IsType(×tamppb.Timestamp{}, test2.Bar) a.Equal(test2.Bar.Seconds, now.Seconds) } + +func TestFetchPbJson(t *testing.T) { + a := assert.New(t) + DB := newDB() + + var test1 TestJsonStruct + err := DB.First(&test1, "SELECT * FROM xsql WHERE id = 1") + if err != nil { + log.Fatal(err) + } + a.NotEmpty(test1.Json) + + var test2 TestJsonStructPtr + err = DB.First(&test2, "SELECT * FROM xsql WHERE id = 1") + if err != nil { + log.Fatal(err) + } + a.NotEmpty(test2.Json) + + var test3 TestJsonSlice + err = DB.First(&test3, "SELECT * FROM xsql WHERE id = 2") + if err != nil { + log.Fatal(err) + } + a.NotEmpty(test3.Json) +} diff --git a/src/xsql/fetcher.go b/src/xsql/fetcher.go index 78e0287..002b6b2 100644 --- a/src/xsql/fetcher.go +++ b/src/xsql/fetcher.go @@ -2,6 +2,7 @@ package xsql import ( "database/sql" + "encoding/json" "errors" "fmt" "github.com/sijms/go-ora/v2" @@ -351,8 +352,7 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect. default: if !res.Empty() { vTyp := reflect.ValueOf(v).Type().String() - // 如果结构体是time.Time类型,执行转换 - if typ.String() == "time.Time" { + if typ.String() == "time.Time" { // 如果结构体是time.Time类型,执行转换 if vTyp == "time.Time" { // parseTime=true v = res.Value() @@ -364,9 +364,7 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect. return fmt.Errorf("time parse fail for field %s: %v", tag, e) } } - } - // 如果结构体是*timestamppb.Timestamp类型,执行转换 - if typ.String() == "*timestamppb.Timestamp" { + } else if typ.String() == "*timestamppb.Timestamp" { // 如果结构体是*timestamppb.Timestamp类型,执行转换 if vTyp != "*timestamppb.Timestamp" { if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { v = timestamppb.New(t) @@ -374,10 +372,27 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect. return fmt.Errorf("time parse fail for field %s: %v", tag, e) } } + } else if typ.Kind() == reflect.Ptr || typ.Kind() == reflect.Struct || typ.Kind() == reflect.Slice || typ.Kind() == reflect.Array { // 非标量用JSON反序列化处理 + jsonString := res.String() + var newInstance reflect.Value + if typ.Kind() == reflect.Ptr { + newInstance = reflect.New(typ.Elem()) // 创建的都是指针 + } else { + newInstance = reflect.New(typ) // 创建的都是指针 + } + if e := json.Unmarshal([]byte(jsonString), newInstance.Interface()); e != nil { + return fmt.Errorf("json unmarshal error for field %s: %v", tag, e) + } + if typ.Kind() == reflect.Ptr { + v = newInstance.Interface() + } else { + v = newInstance.Elem().Interface() // 获取的是非指针 + } } } } - // 追加异常信息 + + // 设置值 defer func() { if e := recover(); e != nil { err = fmt.Errorf("type mismatch for field %s: %v", tag, e)