Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serde Encode Options #466

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 78 additions & 19 deletions pkg/sr/serde.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,32 @@ var (

type (
// SerdeOpt is an option to configure a Serde.
SerdeOpt interface{ apply(*tserde) }
SerdeOpt interface{ applySerde(*tserde) }
serdeOpt struct{ fn func(*tserde) }

// EncodeOpt is an option to configure Serde.Encode and related functions.
EncodeOpt interface{ applyEncode(*encodeOpts) }
encodeOpt struct{ fn func(*encodeOpts) }
encodeOpts struct {
id int
index []int
}

// SerdeEncodeOpt is an option to configure Serde and Serde.Encode.
SerdeEncodeOpt interface {
SerdeOpt
EncodeOpt
}
serdeEncodeOpt struct {
serdeFn func(*tserde)
encodeFn func(*encodeOpts)
}
)

func (o serdeOpt) apply(t *tserde) { o.fn(t) }
func (o serdeOpt) applySerde(t *tserde) { o.fn(t) }
func (o encodeOpt) applyEncode(opts *encodeOpts) { o.fn(opts) }
func (o serdeEncodeOpt) applySerde(t *tserde) { o.serdeFn(t) }
func (o serdeEncodeOpt) applyEncode(opts *encodeOpts) { o.encodeFn(opts) }

// EncodeFn allows Serde to encode a value.
func EncodeFn(fn func(any) ([]byte, error)) SerdeOpt {
Expand Down Expand Up @@ -67,8 +88,16 @@ func GenerateFn(fn func() any) SerdeOpt {
// For more information, see where `message-indexes` are described in:
//
// https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#wire-format
func Index(index ...int) SerdeOpt {
return serdeOpt{func(t *tserde) { t.index = index }}
func Index(index ...int) SerdeEncodeOpt {
return serdeEncodeOpt{
serdeFn: func(t *tserde) { t.index = index },
encodeFn: func(o *encodeOpts) { o.index = index },
}
}

// ID forces Serde.Encode to use the specified schema ID.
func ID(id int) EncodeOpt {
return encodeOpt{func(opts *encodeOpts) { opts.id = id }}
}

type tserde struct {
Expand Down Expand Up @@ -140,10 +169,10 @@ func (s *Serde) SetDefaults(opts ...SerdeOpt) {
func (s *Serde) Register(id int, v any, opts ...SerdeOpt) {
var t tserde
for _, opt := range s.defaults {
opt.apply(&t)
opt.applySerde(&t)
}
for _, opt := range opts {
opt.apply(&t)
opt.applySerde(&t)
}

typeof := reflect.TypeOf(v)
Expand Down Expand Up @@ -221,17 +250,21 @@ func tserdeMapClone(m map[int]tserde, at int, index []int) map[int]tserde {

// Encode encodes a value according to the schema registry wire format and
// returns it. If EncodeFn was not used, this returns ErrNotRegistered.
func (s *Serde) Encode(v any) ([]byte, error) {
return s.AppendEncode(nil, v)
func (s *Serde) Encode(v any, opts ...EncodeOpt) ([]byte, error) {
return s.AppendEncode(nil, v, opts...)
}

// AppendEncode appends an encoded value to b according to the schema registry
// wire format and returns it. If EncodeFn was not used, this returns
// ErrNotRegistered.
func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) {
t, ok := s.loadTypes()[reflect.TypeOf(v)]
if !ok || (t.encode == nil && t.appendEncode == nil) {
return b, ErrNotRegistered
func (s *Serde) AppendEncode(b []byte, v any, opts ...EncodeOpt) ([]byte, error) {
o := encodeOpts{}
for _, opt := range opts {
opt.applyEncode(&o)
}
t, err := s.encodeFind(o, v)
if err != nil {
return nil, err
}

b = append(b,
Expand Down Expand Up @@ -263,10 +296,34 @@ func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) {
return append(b, encoded...), nil
}

func (s *Serde) encodeFind(opts encodeOpts, v any) (tserde, error) {
if opts.id > 0 {
// load tserde based on the supplied ID
t := s.loadIDs()[opts.id]
// traverse to the right index, if any is supplied
for _, i := range opts.index {
if len(t.subindex) <= i {
return tserde{}, ErrNotRegistered
}
t = t.subindex[i]
}
if !t.exists || (t.encode == nil && t.appendEncode == nil) {
return tserde{}, ErrNotRegistered
}
return t, nil
}
// get tserde by type
t := s.loadTypes()[reflect.TypeOf(v)]
if !t.exists || (t.encode == nil && t.appendEncode == nil) {
return tserde{}, ErrNotRegistered
}
return t, nil
}

// MustEncode returns the value of Encode, panicking on error. This is a
// shortcut for if your encode function cannot error.
func (s *Serde) MustEncode(v any) []byte {
b, err := s.Encode(v)
func (s *Serde) MustEncode(v any, opts ...EncodeOpt) []byte {
b, err := s.Encode(v, opts...)
if err != nil {
panic(err)
}
Expand All @@ -275,8 +332,8 @@ func (s *Serde) MustEncode(v any) []byte {

// MustAppendEncode returns the value of AppendEncode, panicking on error.
// This is a shortcut for if your encode function cannot error.
func (s *Serde) MustAppendEncode(b []byte, v any) []byte {
b, err := s.AppendEncode(b, v)
func (s *Serde) MustAppendEncode(b []byte, v any, opts ...EncodeOpt) []byte {
b, err := s.AppendEncode(b, v, opts...)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -309,8 +366,10 @@ func (s *Serde) DecodeNew(b []byte) (any, error) {
var v any
if t.gen != nil {
v = t.gen()
} else {
} else if t.typeof != nil {
v = reflect.New(t.typeof).Interface()
} else {
return nil, ErrNotRegistered
}
return v, t.decode(b, v)
}
Expand Down Expand Up @@ -341,8 +400,8 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) {
}
b = r.b
}
if !t.exists {
return nil, t, ErrNotRegistered
if !t.exists || t.decode == nil {
return nil, tserde{}, ErrNotRegistered
}
return b, t, nil
}
Expand Down
28 changes: 28 additions & 0 deletions pkg/sr/serde_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ func TestSerde(t *testing.T) {
serde.Register(3, idx4{}, Index(0, 0, 1))
serde.Register(3, idx3{}, Index(0, 0))
serde.Register(5, oneidx{}, Index(0), GenerateFn(func() any { return &oneidx{Foo: "defoo", Bar: "debar"} }))
serde.Register(100, nil, Index(0), EncodeFn(func(v any) ([]byte, error) {
return json.MarshalIndent(v, "", " ")
}))

for i, test := range []struct {
enc any
Expand Down Expand Up @@ -91,6 +94,10 @@ func TestSerde(t *testing.T) {
expDec: oneidx{Foo: "defoo", Bar: "bar"},
},
} {
if _, err := serde.Encode(test.enc, ID(99)); err != ErrNotRegistered {
t.Errorf("got %v != exp ErrNotRegistered", err)
}

b, err := serde.Encode(test.enc)
gotErr := err != nil
if gotErr != test.expErr {
Expand All @@ -113,6 +120,13 @@ func TestSerde(t *testing.T) {
t.Errorf("#%d got MustAppendEncode(%v) != Encode(foo%v)", i, b2, b)
}

bIndented := serde.MustEncode(test.enc, ID(100), Index(0))
if i := bytes.IndexByte(bIndented, '{'); !bytes.Equal(bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[:i], []byte{0, 0, 0, 0, 100, 0})
} else if expIndented := extractIndentedJSON(b); !bytes.Equal(bIndented[i:], expIndented) {
t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[i:], expIndented)
}

v, err := serde.DecodeNew(b)
if err != nil {
t.Errorf("#%d DecodeNew: got unexpected err %v", i, err)
Expand All @@ -138,4 +152,18 @@ func TestSerde(t *testing.T) {
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 99}); err != ErrNotRegistered {
t.Errorf("got %v != exp ErrNotRegistered", err)
}
if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 100, 0}); err != ErrNotRegistered {
// schema is registered but type is unknown
t.Errorf("got %v != exp ErrNotRegistered", err)
}
}

func extractIndentedJSON(in []byte) []byte {
i := bytes.IndexByte(in, '{') // skip header
var out bytes.Buffer
err := json.Indent(&out, in[i:], "", " ")
if err != nil {
panic(err)
}
return out.Bytes()
}