Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copier支持类型转换 #200

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# 开发中
- [copier: ReflectCopier copier支持类型转换](https://github.com/ecodeclub/ekit/issues/197)
- [mapx: TreeMap 添加 Keys 和 Values 方法](https://github.com/ecodeclub/ekit/pull/181)
- [mapx: 修正 HashMap 中使用泛型不当的地方](https://github.com/ecodeclub/ekit/pull/186)
- [pool: 重构TaskPool测试用例](https://github.com/ecodeclub/ekit/pull/178)
Expand Down
25 changes: 25 additions & 0 deletions bean/copier/converter/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package converter

type Converter[Src any, Dst any] interface {
Convert(src Src) (Dst, error)
}

type ConverterFunc[Src any, Dst any] func(src Src) (Dst, error)

func (cf ConverterFunc[Src, Dst]) Convert(src Src) (Dst, error) {
return cf(src)
}
25 changes: 25 additions & 0 deletions bean/copier/converter/time2string.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package converter

import "time"

type Time2String struct {
Pattern string
}

func (t Time2String) Convert(src time.Time) (string, error) {
return src.Format(t.Pattern), nil
}
28 changes: 26 additions & 2 deletions bean/copier/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package copier

import (
"github.com/ecodeclub/ekit/bean/copier/converter"
"github.com/ecodeclub/ekit/bean/option"
"github.com/ecodeclub/ekit/set"
)
Expand All @@ -35,10 +36,14 @@ type Copier[Src any, Dst any] interface {
type options struct {
// ignoreFields 执行复制操作时,需要忽略的字段
ignoreFields *set.MapSet[string]
// convertFields 执行转换的field和转化接口的泛型包装
convertFields map[string]converterWrapper
}

func newOptions() *options {
return &options{}
type converterWrapper func(src any) (any, error)

func newOptions() options {
return options{}
}

// InIgnoreFields 判断 str 是不是在 ignoreFields 里面
Expand All @@ -65,3 +70,22 @@ func IgnoreFields(fields ...string) option.Option[options] {
}
}
}

func ConvertField[Src any, Dst any](field string, converter converter.Converter[Src, Dst]) option.Option[options] {
return func(opt *options) {
if field == "" || converter == nil {
return
}
if opt.convertFields == nil {
opt.convertFields = make(map[string]converterWrapper, 8)
}
opt.convertFields[field] = func(src any) (any, error) {
var dst Dst
srcVal, ok := src.(Src)
if !ok {
flycash marked this conversation as resolved.
Show resolved Hide resolved
return dst, errConvertFieldTypeNotMatch
}
return converter.Convert(srcVal)
}
}
}
5 changes: 5 additions & 0 deletions bean/copier/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
package copier

import (
"errors"
"fmt"
"reflect"
)

var (
errConvertFieldTypeNotMatch = errors.New("ekit: 转化字段类型不匹配")
)

// newErrTypeError copier 不支持的类型
func newErrTypeError(typ reflect.Type) error {
return fmt.Errorf("ekit: copier 入口只支持 Struct 不支持类型 %v, 种类 %v", typ, typ.Kind())
Expand Down
149 changes: 115 additions & 34 deletions bean/copier/reflect_copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ package copier

import (
"reflect"
"time"

"github.com/ecodeclub/ekit/set"

"github.com/ecodeclub/ekit/bean/option"
)

var defaultAtomicTypes = []reflect.Type{
reflect.TypeOf(time.Time{}),
}

// ReflectCopier 基于反射的实现
// ReflectCopier 是浅拷贝
type ReflectCopier[Src any, Dst any] struct {
Expand All @@ -28,7 +35,11 @@ type ReflectCopier[Src any, Dst any] struct {
rootField fieldNode

// options 执行复制操作时的可选配置
options *options
// 如果默认配置和Copy()/CopyTo()中的配置同名,会替换defaultOptions同名内容
// 初始化时的默认配置,仅作为记录,执行时会拷贝到options中
defaultOptions options

atomicTypes []reflect.Type
}

// fieldNode 字段的前缀树
Expand All @@ -50,7 +61,7 @@ type fieldNode struct {
}

// NewReflectCopier 如果类型不匹配, 创建时直接检查报错.
func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
func NewReflectCopier[Src any, Dst any](opts ...option.Option[options]) (*ReflectCopier[Src, Dst], error) {
src := new(Src)
srcTyp := reflect.TypeOf(src).Elem()
dst := new(Dst)
Expand All @@ -65,18 +76,24 @@ func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
if dstTyp.Kind() != reflect.Struct {
return nil, newErrTypeError(dstTyp)
}
if err := createFieldNodes(&root, srcTyp, dstTyp); err != nil {
return nil, err
}

copier := &ReflectCopier[Src, Dst]{
flycash marked this conversation as resolved.
Show resolved Hide resolved
rootField: root,
atomicTypes: defaultAtomicTypes,
}

if err := copier.createFieldNodes(&root, srcTyp, dstTyp); err != nil {
return nil, err
}
copier.rootField = root

defaultOpts := newOptions()
option.Apply(&defaultOpts, opts...)
copier.defaultOptions = defaultOpts
return copier, nil
}

// createFieldNodes 递归创建 field 的前缀树, srcTyp 和 dstTyp 只能是结构体
func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
func (r *ReflectCopier[Src, Dst]) createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {

fieldMap := map[string]int{}
for i := 0; i < srcTyp.NumField(); i++ {
Expand All @@ -98,17 +115,12 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
continue
}
srcFieldTypStruct := srcTyp.Field(srcIndex)
if srcFieldTypStruct.Type.Kind() != dstFieldTypStruct.Type.Kind() {
return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name)
}

if srcFieldTypStruct.Type.Kind() == reflect.Pointer {
if srcFieldTypStruct.Type.Elem().Kind() != dstFieldTypStruct.Type.Elem().Kind() {
return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name)
}
if srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(dstFieldTypStruct.Name)
}
if srcFieldTypStruct.Type.Kind() == reflect.Pointer && srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(srcFieldTypStruct.Name)
}
if dstFieldTypStruct.Type.Kind() == reflect.Pointer && dstFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(dstFieldTypStruct.Name)
}

child := fieldNode{
Expand All @@ -123,18 +135,22 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
fieldDstTyp := dstFieldTypStruct.Type
if fieldSrcTyp.Kind() == reflect.Pointer {
fieldSrcTyp = fieldSrcTyp.Elem()
}

if fieldDstTyp.Kind() == reflect.Pointer {
fieldDstTyp = fieldDstTyp.Elem()
}

if isShadowCopyType(fieldSrcTyp.Kind()) {
// 内置类型,但不匹配,如别名、map和slice
if fieldSrcTyp != fieldDstTyp {
return newErrTypeNotMatchError(srcFieldTypStruct.Type, dstFieldTypStruct.Type, dstFieldTypStruct.Name)
}
// 说明当前节点是叶子节点, 直接拷贝
child.isLeaf = true
} else if r.isAtomicType(fieldSrcTyp) {
// 指定可作为一个整体的类型,不用递归
// 同上,当当前节点是叶子节点时, 直接拷贝
child.isLeaf = true
} else if fieldSrcTyp.Kind() == reflect.Struct {
if err := createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil {
if err := r.createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil {
return err
}
} else {
Expand All @@ -160,49 +176,105 @@ func (r *ReflectCopier[Src, Dst]) Copy(src *Src, opts ...option.Option[options])
// 3. 如果 Src 和 Dst 中匹配的字段,其类型都是结构体,或者都是结构体指针,则会深入复制
// 4. 否则,忽略字段
func (r *ReflectCopier[Src, Dst]) CopyTo(src *Src, dst *Dst, opts ...option.Option[options]) error {
opt := newOptions()
option.Apply(opt, opts...)
r.options = opt
localOption := r.copyDefaultOptions()
option.Apply(&localOption, opts...)
return r.copyToWithTree(src, dst, localOption)
}

return r.copyToWithTree(src, dst)
// copyDefaultOptions 复制默认配置
func (r *ReflectCopier[Src, Dst]) copyDefaultOptions() options {
localOption := newOptions()
// 复制ignoreFields default配置
if r.defaultOptions.ignoreFields != nil {
ignoreFields := set.NewMapSet[string](8)
for _, key := range r.defaultOptions.ignoreFields.Keys() {
ignoreFields.Add(key)
}
localOption.ignoreFields = ignoreFields
}

// 复制convertFields default配置
for field, convert := range r.defaultOptions.convertFields {
if localOption.convertFields == nil {
localOption.convertFields = make(map[string]converterWrapper, 8)
}
localOption.convertFields[field] = convert
}
flycash marked this conversation as resolved.
Show resolved Hide resolved
return localOption
}

func (r *ReflectCopier[Src, Dst]) copyToWithTree(src *Src, dst *Dst) error {
func (r *ReflectCopier[Src, Dst]) copyToWithTree(src *Src, dst *Dst, opts options) error {
srcTyp := reflect.TypeOf(src)
dstTyp := reflect.TypeOf(dst)
srcValue := reflect.ValueOf(src)
dstValue := reflect.ValueOf(dst)

return r.copyTreeNode(srcTyp, srcValue, dstTyp, dstValue, &r.rootField)
return r.copyTreeNode(srcTyp, srcValue, dstTyp, dstValue, &r.rootField, opts)
}

func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue reflect.Value, dstType reflect.Type, dstValue reflect.Value, root *fieldNode) error {
func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue reflect.Value,
dstType reflect.Type, dstValue reflect.Value, root *fieldNode, opts options) error {
originSrcVal := srcValue
originDstVal := dstValue
if srcValue.Kind() == reflect.Pointer {
if srcValue.IsNil() {
return nil
}
if dstValue.IsNil() {
dstValue.Set(reflect.New(dstType.Elem()))
}
srcValue = srcValue.Elem()
srcTyp = srcTyp.Elem()
}

if dstValue.Kind() == reflect.Pointer {
flycash marked this conversation as resolved.
Show resolved Hide resolved
if dstValue.IsNil() {
dstValue.Set(reflect.New(dstType.Elem()))
}
dstValue = dstValue.Elem()
dstType = dstType.Elem()
}

// 执行拷贝
if root.isLeaf {
if dstValue.CanSet() {
convert, ok := opts.convertFields[root.name]
if !dstValue.CanSet() {
return nil
}
// 获取convert失败,就需要检测类型是否匹配,类型匹配就直接set
if !ok {
if srcTyp != dstType {
return newErrTypeNotMatchError(srcTyp, dstType, root.name)
}
if srcValue.IsZero() {
return nil
}
dstValue.Set(srcValue)
return nil
}

// 字段执行转换函数时,需要用到原始类型进行判断,set的时候也是根据原始value设置
if !originDstVal.CanSet() {
return nil
}
srcConv, err := convert(originSrcVal.Interface())
flycash marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

srcConvType := reflect.TypeOf(srcConv)
srcConvVal := reflect.ValueOf(srcConv)
// 待设置的value和转换获取的value类型不匹配
if srcConvType != originDstVal.Type() {
return newErrTypeNotMatchError(srcConvType, originDstVal.Type(), root.name)
}

originDstVal.Set(srcConvVal)
return nil
}

for i := range root.fields {
child := &root.fields[i]

// 只要结构体属性的名字在需要忽略的字段里面,就不走下面的复制逻辑
if r.options.InIgnoreFields(child.name) {
if opts.InIgnoreFields(child.name) {
continue
}

Expand All @@ -211,13 +283,22 @@ func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue ref

childDstTyp := dstType.Field(child.dstIndex)
childDstValue := dstValue.Field(child.dstIndex)
if err := r.copyTreeNode(childSrcTyp.Type, childSrcValue, childDstTyp.Type, childDstValue, child); err != nil {
if err := r.copyTreeNode(childSrcTyp.Type, childSrcValue, childDstTyp.Type, childDstValue, child, opts); err != nil {
return err
}
}
return nil
}

func (r *ReflectCopier[Src, Dst]) isAtomicType(typ reflect.Type) bool {
for _, dt := range r.atomicTypes {
if dt == typ {
return true
}
}
return false
}

func isShadowCopyType(kind reflect.Kind) bool {
switch kind {
case reflect.Bool,
Expand Down
Loading