Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copier支持类型转换 #200

Merged
merged 5 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# 开发中
- [copier: ReflectCopier copier支持类型转换](https://github.com/ecodeclub/ekit/issues/197)
- [mapx: TreeMap 添加 Keys 和 Values 方法](https://github.com/ecodeclub/ekit/pull/181)
- [mapx: 修正 HashMap 中使用泛型不当的地方](https://github.com/ecodeclub/ekit/pull/186)
- [pool: 重构TaskPool测试用例](https://github.com/ecodeclub/ekit/pull/178)
Expand Down
37 changes: 37 additions & 0 deletions bean/copier/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package copier

import (
"time"
)

type Converter[Src any, Dst any] interface {
Convert(src Src) (Dst, error)
}

type Time2String struct {
Pattern string
}

func (t Time2String) Convert(src time.Time) (string, error) {
return src.Format(t.Pattern), nil
}
hookokoko marked this conversation as resolved.
Show resolved Hide resolved

type ConverterFunc[Src any, Dst any] func(src Src) (Dst, error)

func (cf ConverterFunc[Src, Dst]) Convert(src Src) (Dst, error) {
return cf(src)
}
23 changes: 23 additions & 0 deletions bean/copier/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ type Copier[Src any, Dst any] interface {
type options struct {
// ignoreFields 执行复制操作时,需要忽略的字段
ignoreFields *set.MapSet[string]
// convertFields 执行转换的field和转化接口的泛型包装
convertFields map[string]converterWrapper
}

type converterWrapper func(src any) (any, error)

func newOptions() *options {
return &options{}
}
Expand Down Expand Up @@ -65,3 +69,22 @@ func IgnoreFields(fields ...string) option.Option[options] {
}
}
}

func ConvertField[Src any, Dst any](field string, converter Converter[Src, Dst]) option.Option[options] {
return func(opt *options) {
if field == "" || converter == nil {
return
}
if opt.convertFields == nil {
opt.convertFields = make(map[string]converterWrapper, 16)
hookokoko marked this conversation as resolved.
Show resolved Hide resolved
}
opt.convertFields[field] = func(src any) (any, error) {
var dst Dst
srcVal, ok := src.(Src)
if !ok {
flycash marked this conversation as resolved.
Show resolved Hide resolved
return dst, ErrConvertFieldTypeNotMatch
}
return converter.Convert(srcVal)
}
}
}
5 changes: 5 additions & 0 deletions bean/copier/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
package copier

import (
"errors"
"fmt"
"reflect"
)

var (
ErrConvertFieldTypeNotMatch = errors.New("ekit: 转化字段类型不匹配")
hookokoko marked this conversation as resolved.
Show resolved Hide resolved
)

// newErrTypeError copier 不支持的类型
func newErrTypeError(typ reflect.Type) error {
return fmt.Errorf("ekit: copier 入口只支持 Struct 不支持类型 %v, 种类 %v", typ, typ.Kind())
Expand Down
89 changes: 68 additions & 21 deletions bean/copier/reflect_copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package copier

import (
"reflect"
"time"

"github.com/ecodeclub/ekit/bean/option"
)
Expand Down Expand Up @@ -50,7 +51,7 @@ type fieldNode struct {
}

// NewReflectCopier 如果类型不匹配, 创建时直接检查报错.
func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
func NewReflectCopier[Src any, Dst any](opts ...option.Option[options]) (*ReflectCopier[Src, Dst], error) {
src := new(Src)
srcTyp := reflect.TypeOf(src).Elem()
dst := new(Dst)
Expand All @@ -72,6 +73,11 @@ func NewReflectCopier[Src any, Dst any]() (*ReflectCopier[Src, Dst], error) {
copier := &ReflectCopier[Src, Dst]{
flycash marked this conversation as resolved.
Show resolved Hide resolved
rootField: root,
}

opt := newOptions()
option.Apply(opt, opts...)
copier.options = opt

return copier, nil
}

Expand All @@ -98,17 +104,12 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
continue
}
srcFieldTypStruct := srcTyp.Field(srcIndex)
if srcFieldTypStruct.Type.Kind() != dstFieldTypStruct.Type.Kind() {
return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name)
}

if srcFieldTypStruct.Type.Kind() == reflect.Pointer {
if srcFieldTypStruct.Type.Elem().Kind() != dstFieldTypStruct.Type.Elem().Kind() {
return newErrKindNotMatchError(srcFieldTypStruct.Type.Kind(), dstFieldTypStruct.Type.Kind(), dstFieldTypStruct.Name)
}
if srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(dstFieldTypStruct.Name)
}
if srcFieldTypStruct.Type.Kind() == reflect.Pointer && srcFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(srcFieldTypStruct.Name)
}
if dstFieldTypStruct.Type.Kind() == reflect.Pointer && dstFieldTypStruct.Type.Elem().Kind() == reflect.Pointer {
return newErrMultiPointer(dstFieldTypStruct.Name)
}

child := fieldNode{
Expand All @@ -123,16 +124,18 @@ func createFieldNodes(root *fieldNode, srcTyp, dstTyp reflect.Type) error {
fieldDstTyp := dstFieldTypStruct.Type
if fieldSrcTyp.Kind() == reflect.Pointer {
fieldSrcTyp = fieldSrcTyp.Elem()
}

if fieldDstTyp.Kind() == reflect.Pointer {
fieldDstTyp = fieldDstTyp.Elem()
}

if isShadowCopyType(fieldSrcTyp.Kind()) {
// 内置类型,但不匹配,如别名、map和slice
if fieldSrcTyp != fieldDstTyp {
return newErrTypeNotMatchError(srcFieldTypStruct.Type, dstFieldTypStruct.Type, dstFieldTypStruct.Name)
}
// 说明当前节点是叶子节点, 直接拷贝
child.isLeaf = true
} else if fieldSrcTyp == reflect.TypeOf(time.Time{}) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个地方加一个注释,因为也不仅仅是 time.Time,类似于 sql.NullXXX 的应该也是这么处理的。所以将来我们可能需要考虑允许用户指定什么样的类型是一个整体,不用递归下去。

然后引入一个 atomicTypes 字段,里面放着所有的类似于 time,Time 这种被看做整体的类型。同时有一个 defaultActomicTypes 作为包变量,atomicTypes 的默认取值就是这个默认值。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我现在的实现是在ReflectCopier结构体中加了一个atomicTypes,但是目前的实现看没有用到。直接就用defaultActomicTypes了。后续是要将ReflectCopier.atomicTypes允许用户修改,然后赋值给包变量defaultActomicTypes吗?

child.isLeaf = true
} else if fieldSrcTyp.Kind() == reflect.Struct {
if err := createFieldNodes(&child, fieldSrcTyp, fieldDstTyp); err != nil {
return err
Expand Down Expand Up @@ -160,9 +163,13 @@ func (r *ReflectCopier[Src, Dst]) Copy(src *Src, opts ...option.Option[options])
// 3. 如果 Src 和 Dst 中匹配的字段,其类型都是结构体,或者都是结构体指针,则会深入复制
// 4. 否则,忽略字段
func (r *ReflectCopier[Src, Dst]) CopyTo(src *Src, dst *Dst, opts ...option.Option[options]) error {
opt := newOptions()
option.Apply(opt, opts...)
r.options = opt
if r.options == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

options 在它初始化 ReflectCopier 的时候就创建一个出来。正常这个 Copier 都是会被复用的,所有不需要特别担心性能问题。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我发现这里之前的实现有问题。

也就是,实际上 ReflectCopier 里面的 options 应该是整个 Copier 维度的 options,但是单词调用是可以被覆盖的。

所以也就是你在这里,要把 ReflectCopier 里面的 option 复制出来一份。为了规避内存逃逸的问题,ReflectCopier 中的 options 不要用指针。

在这里,你会把 ReflectCopier 里面的 options 和 opts 进行合并,合并后的结果才是你最终的结果。

举例来说:默认初始化 ReflectCopier 的时候,我用了 ConvertField("A", converter1),但是当我在调用 Copy 方法的时候,我用了 ConvertField("A". converter2) 和 ConvertField("B", converter3),那么最终生效的就是 A converter2, 和 B converter3。

当我再一次发起 Copy 调用的时候,我依旧用的是 A converter1,也就是 Copier 本身的配置。

opt := newOptions()
option.Apply(opt, opts...)
r.options = opt
} else {
option.Apply(r.options, opts...)
}

return r.copyToWithTree(src, dst)
}
Expand All @@ -177,24 +184,64 @@ func (r *ReflectCopier[Src, Dst]) copyToWithTree(src *Src, dst *Dst) error {
}

func (r *ReflectCopier[Src, Dst]) copyTreeNode(srcTyp reflect.Type, srcValue reflect.Value, dstType reflect.Type, dstValue reflect.Value, root *fieldNode) error {
originSrcVal := srcValue
originDstVal := dstValue
if srcValue.Kind() == reflect.Pointer {
if srcValue.IsNil() {
return nil
}
if dstValue.IsNil() {
dstValue.Set(reflect.New(dstType.Elem()))
}
srcValue = srcValue.Elem()
srcTyp = srcTyp.Elem()
}

if dstValue.Kind() == reflect.Pointer {
flycash marked this conversation as resolved.
Show resolved Hide resolved
if dstValue.IsNil() {
dstValue.Set(reflect.New(dstType.Elem()))
}
dstValue = dstValue.Elem()
dstType = dstType.Elem()
}

// 执行拷贝
if root.isLeaf {
if dstValue.CanSet() {
convert, ok := r.options.convertFields[root.name]
hookokoko marked this conversation as resolved.
Show resolved Hide resolved
// 获取convert失败,就需要检测类型是否匹配
if !ok && srcTyp.Kind() != dstType.Kind() {
return newErrKindNotMatchError(srcTyp.Kind(), dstType.Kind(), root.name)
}
if !ok && srcTyp != dstType {
hookokoko marked this conversation as resolved.
Show resolved Hide resolved
return newErrTypeNotMatchError(srcTyp, dstType, root.name)
}
// 获取convert失败,类型匹配就直接set
if !ok && dstValue.CanSet() {
if srcValue.IsZero() {
return nil
}
dstValue.Set(srcValue)
return nil
}

// 字段执行转换函数时,需要用到原始类型进行判断
srcConv, err := convert(originSrcVal.Interface())
flycash marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

srcConvType := reflect.TypeOf(srcConv)
srcConvVal := reflect.ValueOf(srcConv)
// 待设置的value和转换获取的value类型不匹配
if srcConvType != originDstVal.Type() {
return newErrTypeNotMatchError(srcConvType, originDstVal.Type(), root.name)
}

if srcConvType.Kind() == reflect.Ptr {
srcConvVal = srcConvVal.Elem()
}
hookokoko marked this conversation as resolved.
Show resolved Hide resolved

if dstValue.CanSet() && srcConvVal.IsValid() {
dstValue.Set(srcConvVal)
}

return nil
}

Expand Down
Loading