Skip to content

Commit

Permalink
Update xsql BuildTagValues Embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
onanying committed Mar 22, 2024
1 parent 6939ac8 commit 918bc47
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
15 changes: 15 additions & 0 deletions src/xsql/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,21 @@ func TestUpdateBuildTagValues(t *testing.T) {
a.Empty(err)
}

func TestEmbeddingUpdateBuildTagValues(t *testing.T) {
a := assert.New(t)

DB := newDB()

test := EmbeddingTest{}
data, err := xsql.BuildTagValues(DB.Options.Tag, &test,
&test.Foo, "test_update_4",
)
a.Empty(err)

_, err = DB.Model(&test).Update(data, "id = ?", 8)
a.Empty(err)
}

func TestDelete(t *testing.T) {
a := assert.New(t)

Expand Down
47 changes: 34 additions & 13 deletions src/xsql/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,31 @@ func BuildTagValues(tagKey string, ptr interface{}, pairs ...interface{}) (map[s
}

result := make(map[string]interface{})
structValue := reflect.ValueOf(ptr).Elem()
structType := structValue.Type()
value := reflect.ValueOf(ptr).Elem()

if structType.Kind() != reflect.Struct {
if value.Kind() != reflect.Struct {
return nil, fmt.Errorf("xsql: ptr must be a pointer to a struct")
}

fieldsMap := map[string]reflect.Value{}
populateFieldsMap(value, fieldsMap)

for i := 0; i < len(pairs); i += 2 {
fieldPtrValue := reflect.ValueOf(pairs[i])
if fieldPtrValue.Kind() != reflect.Ptr || fieldPtrValue.Elem().Kind() == reflect.Ptr {
fieldPtr, ok := pairs[i].(interface{})
if !ok {
return nil, fmt.Errorf("xsql: argument at index %d is not a pointer", i)
}

fieldValue := reflect.ValueOf(fieldPtr)
if fieldValue.Kind() != reflect.Ptr || fieldValue.IsNil() {
return nil, fmt.Errorf("xsql: argument at index %d must be a non-nil pointer to a struct field", i)
}

var fieldName string
found := false
for j := 0; j < structValue.NumField(); j++ {
if structValue.Field(j).Addr().Interface() == pairs[i] {
fieldName = structType.Field(j).Tag.Get(tagKey)
if fieldName == "" {
return nil, fmt.Errorf("xsql: no struct field tag found for pointer at index %d", i)
}
var found bool
for name, field := range fieldsMap {
if field.Addr().Interface() == fieldPtr {
fieldName = name
found = true
break
}
Expand All @@ -41,9 +45,26 @@ func BuildTagValues(tagKey string, ptr interface{}, pairs ...interface{}) (map[s
return nil, fmt.Errorf("xsql: no matching struct field found for pointer at index %d", i)
}

// Set the field name and value in the map
result[fieldName] = pairs[i+1]
}

return result, nil
}

// populateFieldsMap is a recursive function that maps field names to their values,
// including fields from embedded structs.
func populateFieldsMap(v reflect.Value, fieldsMap map[string]reflect.Value) {
for i := 0; i < v.NumField(); i++ {
fieldValue := v.Field(i)
fieldType := v.Type().Field(i)
tag := fieldType.Tag.Get("xsql")
// If it's an embedded struct, we need to recurse into it
if fieldType.Anonymous && fieldValue.Type().Kind() == reflect.Struct {
populateFieldsMap(fieldValue, fieldsMap)
} else if tag != "" {
// Only add the field if it has the xsql tag
fieldName := tag
fieldsMap[fieldName] = fieldValue
}
}
}

0 comments on commit 918bc47

Please sign in to comment.