diff --git a/pkg/sr/api.go b/pkg/sr/api.go index eb50b521..fa6fd9f1 100644 --- a/pkg/sr/api.go +++ b/pkg/sr/api.go @@ -216,7 +216,7 @@ func (cl *Client) SubjectsByID(ctx context.Context, id int) ([]string, error) { return subjects, err } -// SchemaVersion is a subject version pair. +// SubjectVersion is a subject version pair. type SubjectVersion struct { Subject string `json:"subject"` Version int `json:"version"` @@ -602,7 +602,7 @@ type SetCompatibility struct { OverrideRuleSet *SchemaRuleSet `json:"overrideRuleSet,omitempty"` // Override rule set used for schema registration. } -// SetCompatibilitysets the compatibility for each requested subject. The +// SetCompatibility sets the compatibility for each requested subject. The // global compatibility can be set by either using an empty subject or by // specifying no subjects. If specifying no subjects, this returns one element. func (cl *Client) SetCompatibility(ctx context.Context, compat SetCompatibility, subjects ...string) []CompatibilityResult { diff --git a/pkg/sr/client.go b/pkg/sr/client.go index 0a53618c..39e52306 100644 --- a/pkg/sr/client.go +++ b/pkg/sr/client.go @@ -63,7 +63,7 @@ type Client struct { } // NewClient returns a new schema registry client. -func NewClient(opts ...Opt) (*Client, error) { +func NewClient(opts ...ClientOpt) (*Client, error) { cl := &Client{ urls: []string{"http://localhost:8081"}, httpcl: &http.Client{Timeout: 5 * time.Second}, diff --git a/pkg/sr/config.go b/pkg/sr/clientopt.go similarity index 69% rename from pkg/sr/config.go rename to pkg/sr/clientopt.go index 499e0ab6..b8d04090 100644 --- a/pkg/sr/config.go +++ b/pkg/sr/clientopt.go @@ -9,30 +9,30 @@ import ( ) type ( - // Opt is an option to configure a client. - Opt interface{ apply(*Client) } - opt struct{ fn func(*Client) } + // ClientOpt is an option to configure a client. + ClientOpt interface{ apply(*Client) } + clientOpt struct{ fn func(*Client) } ) -func (o opt) apply(cl *Client) { o.fn(cl) } +func (o clientOpt) apply(cl *Client) { o.fn(cl) } // HTTPClient sets the http client that the schema registry client uses, // overriding the default client that speaks plaintext with a timeout of 5s. -func HTTPClient(httpcl *http.Client) Opt { - return opt{func(cl *Client) { cl.httpcl = httpcl }} +func HTTPClient(httpcl *http.Client) ClientOpt { + return clientOpt{func(cl *Client) { cl.httpcl = httpcl }} } // UserAgent sets the User-Agent to use in requests, overriding the default // "franz-go". -func UserAgent(ua string) Opt { - return opt{func(cl *Client) { cl.ua = ua }} +func UserAgent(ua string) ClientOpt { + return clientOpt{func(cl *Client) { cl.ua = ua }} } // URLs sets the URLs that the client speaks to, overriding the default // http://localhost:8081. This option automatically prefixes any URL that is // missing an http:// or https:// prefix with http://. -func URLs(urls ...string) Opt { - return opt{func(cl *Client) { +func URLs(urls ...string) ClientOpt { + return clientOpt{func(cl *Client) { for i, u := range urls { if strings.HasPrefix(u, "http://") || strings.HasPrefix(u, "https://") { continue @@ -45,8 +45,8 @@ func URLs(urls ...string) Opt { } // DialTLSConfig sets a tls.Config to use in the default http client. -func DialTLSConfig(c *tls.Config) Opt { - return opt{func(cl *Client) { +func DialTLSConfig(c *tls.Config) ClientOpt { + return clientOpt{func(cl *Client) { cl.httpcl = &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ @@ -68,8 +68,8 @@ func DialTLSConfig(c *tls.Config) Opt { } // BasicAuth sets basic authorization to use for every request. -func BasicAuth(user, pass string) Opt { - return opt{func(cl *Client) { +func BasicAuth(user, pass string) ClientOpt { + return clientOpt{func(cl *Client) { cl.basicAuth = &struct { user string pass string @@ -78,8 +78,8 @@ func BasicAuth(user, pass string) Opt { } // DefaultParams sets default parameters to apply to every request. -func DefaultParams(ps ...Param) Opt { - return opt{func(cl *Client) { +func DefaultParams(ps ...Param) ClientOpt { + return clientOpt{func(cl *Client) { cl.defParams = mergeParams(ps...) }} } diff --git a/pkg/sr/enums.go b/pkg/sr/enums.go index 2c1b1ca7..f37c1038 100644 --- a/pkg/sr/enums.go +++ b/pkg/sr/enums.go @@ -197,7 +197,7 @@ func (k *SchemaRuleKind) UnmarshalText(text []byte) error { return nil } -// Mode specifies a schema rule's mode. +// SchemaRuleMode specifies a schema rule's mode. // // Migration rules can be specified for an UPGRADE, DOWNGRADE, or both // (UPDOWN). Migration rules are used during complex schema evolution. diff --git a/pkg/sr/serde.go b/pkg/sr/serde.go index 909701f8..9593946d 100644 --- a/pkg/sr/serde.go +++ b/pkg/sr/serde.go @@ -21,33 +21,53 @@ var ( ) type ( - // SerdeOpt is an option to configure a Serde. - SerdeOpt interface{ apply(*tserde) } - serdeOpt struct{ fn func(*tserde) } + // EncodingOpt is an option to configure the behavior of Serde.Encode and + // Serde.Decode. + EncodingOpt interface { + serdeOrEncodingOpt() + apply(*tserde) + } + encodingOpt struct{ fn func(*tserde) } + + // SerdeOpt is an option to configure Serde. + SerdeOpt interface { + serdeOrEncodingOpt() + apply(*Serde) + } + serdeOpt struct{ fn func(serde *Serde) } + + // SerdeOrEncodingOpt is either a SerdeOpt or EncodingOpt. + SerdeOrEncodingOpt interface { + serdeOrEncodingOpt() + } ) -func (o serdeOpt) apply(t *tserde) { o.fn(t) } +func (o serdeOpt) serdeOrEncodingOpt() { /* satisfy interface */ } +func (o serdeOpt) apply(s *Serde) { o.fn(s) } + +func (o encodingOpt) serdeOrEncodingOpt() { /* satisfy interface */ } +func (o encodingOpt) apply(t *tserde) { o.fn(t) } // EncodeFn allows Serde to encode a value. -func EncodeFn(fn func(any) ([]byte, error)) SerdeOpt { - return serdeOpt{func(t *tserde) { t.encode = fn }} +func EncodeFn(fn func(any) ([]byte, error)) EncodingOpt { + return encodingOpt{func(t *tserde) { t.encode = fn }} } // AppendEncodeFn allows Serde to encode a value to an existing slice. This // can be more efficient than EncodeFn; this function is used if it exists. -func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) SerdeOpt { - return serdeOpt{func(t *tserde) { t.appendEncode = fn }} +func AppendEncodeFn(fn func([]byte, any) ([]byte, error)) EncodingOpt { + return encodingOpt{func(t *tserde) { t.appendEncode = fn }} } // DecodeFn allows Serde to decode into a value. -func DecodeFn(fn func([]byte, any) error) SerdeOpt { - return serdeOpt{func(t *tserde) { t.decode = fn }} +func DecodeFn(fn func([]byte, any) error) EncodingOpt { + return encodingOpt{func(t *tserde) { t.decode = fn }} } // GenerateFn returns a new(Value) that can be decoded into. This function can // be used to control the instantiation of a new type for DecodeNew. -func GenerateFn(fn func() any) SerdeOpt { - return serdeOpt{func(t *tserde) { t.gen = fn }} +func GenerateFn(fn func() any) EncodingOpt { + return encodingOpt{func(t *tserde) { t.gen = fn }} } // Index attaches a message index to a value. A single schema ID can be @@ -62,8 +82,13 @@ func GenerateFn(fn func() any) SerdeOpt { // For more information, see where `message-indexes` are described in: // // https://docs.confluent.io/platform/current/schema-registry/serdes-develop/index.html#wire-format -func Index(index ...int) SerdeOpt { - return serdeOpt{func(t *tserde) { t.index = index }} +func Index(index ...int) EncodingOpt { + return encodingOpt{func(t *tserde) { t.index = index }} +} + +// Header defines the SerdeHeader used to encode and decode the message header. +func Header(header SerdeHeader) SerdeOpt { + return serdeOpt{func(s *Serde) { s.h = header }} } type tserde struct { @@ -96,7 +121,7 @@ type Serde struct { types atomic.Value // map[reflect.Type]tserde mu sync.Mutex - defaults []SerdeOpt + defaults []EncodingOpt h SerdeHeader } @@ -105,6 +130,25 @@ var ( noTypes = make(map[reflect.Type]tserde) ) +// NewSerde returns a new Serde using the supplied default options, which are +// applied to every registered type. These options are always applied first, so +// you can override them as necessary when registering. +// +// This can be useful if you always want to use the same encoding or decoding +// functions. +func NewSerde(opts ...SerdeOrEncodingOpt) *Serde { + var s Serde + for _, opt := range opts { + switch opt := opt.(type) { + case SerdeOpt: + opt.apply(&s) + case EncodingOpt: + s.defaults = append(s.defaults, opt) + } + } + return &s +} + func (s *Serde) loadIDs() map[int]tserde { ids := s.ids.Load() if ids == nil { @@ -121,16 +165,6 @@ func (s *Serde) loadTypes() map[reflect.Type]tserde { return types.(map[reflect.Type]tserde) } -// SetDefaults sets default options to apply to every registered type. These -// options are always applied first, so you can override them as necessary when -// registering. -// -// This can be useful if you always want to use the same encoding or decoding -// functions. -func (s *Serde) SetDefaults(opts ...SerdeOpt) { - s.defaults = opts -} - // DecodeID decodes an ID from b, returning the ID and the remaining bytes, // or an error. func (s *Serde) DecodeID(b []byte) (id int, out []byte, err error) { @@ -154,7 +188,7 @@ func (s *Serde) header() SerdeHeader { // Register registers a schema ID and the value it corresponds to, as well as // the encoding or decoding functions. You need to register functions depending // on whether you are only encoding, only decoding, or both. -func (s *Serde) Register(id int, v any, opts ...SerdeOpt) { +func (s *Serde) Register(id int, v any, opts ...EncodingOpt) { var t tserde for _, opt := range s.defaults { opt.apply(&t) @@ -258,20 +292,18 @@ func (s *Serde) Encode(v any) ([]byte, error) { return s.AppendEncode(nil, v) } -// AppendEncode appends an encoded value to b according to the schema registry -// wire format and returns it. If EncodeFn was not used, this returns -// ErrNotRegistered. +// AppendEncode encodes a value and prepends the header according to the +// configured SerdeHeader, appends it to b and returns b. If EncodeFn was not +// registered, this returns ErrNotRegistered. func (s *Serde) AppendEncode(b []byte, v any) ([]byte, error) { t, ok := s.loadTypes()[reflect.TypeOf(v)] if !ok || (t.encode == nil && t.appendEncode == nil) { return b, ErrNotRegistered } - b, err := s.header().AppendEncode(b, int(t.id), t.index) if err != nil { return nil, err } - if t.appendEncode != nil { return t.appendEncode(b, v) } @@ -328,8 +360,10 @@ func (s *Serde) DecodeNew(b []byte) (any, error) { var v any if t.gen != nil { v = t.gen() - } else { + } else if t.typeof != nil { v = reflect.New(t.typeof).Interface() + } else { + return nil, ErrNotRegistered } return v, t.decode(b, v) } @@ -339,7 +373,6 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) { if err != nil { return nil, tserde{}, err } - t := s.loadIDs()[id] if len(t.subindex) > 0 { var index []int @@ -360,6 +393,28 @@ func (s *Serde) decodeFind(b []byte) ([]byte, tserde, error) { return b, t, nil } +// Encode encodes a value and prepends the header. If the encoding function +// fails, this returns an error. +func Encode(v any, id int, index []int, h SerdeHeader, enc func(any) ([]byte, error)) ([]byte, error) { + return AppendEncode(nil, v, id, index, h, func(b []byte, val any) ([]byte, error) { + encoded, err := enc(val) + if err != nil { + return nil, err + } + return append(b, encoded...), nil + }) +} + +// AppendEncode encodes a value and prepends the header, appends it to b and +// returns b. If the encoding function fails, this returns an error. +func AppendEncode(b []byte, v any, id int, index []int, h SerdeHeader, enc func([]byte, any) ([]byte, error)) ([]byte, error) { + b, err := h.AppendEncode(b, id, index) + if err != nil { + return nil, err + } + return enc(b, v) +} + // SerdeHeader encodes and decodes a message header. type SerdeHeader interface { // AppendEncode encodes a schema ID and optional index to b, returning the diff --git a/pkg/sr/serde_test.go b/pkg/sr/serde_test.go index 4b591fa0..931768cb 100644 --- a/pkg/sr/serde_test.go +++ b/pkg/sr/serde_test.go @@ -36,8 +36,7 @@ func TestSerde(t *testing.T) { } ) - var serde Serde - serde.SetDefaults( + serde := NewSerde( EncodeFn(json.Marshal), DecodeFn(json.Unmarshal), ) @@ -113,6 +112,19 @@ func TestSerde(t *testing.T) { t.Errorf("#%d got MustAppendEncode(%v) != Encode(foo%v)", i, b2, b) } + bIndented, err := Encode(test.enc, 100, []int{0}, serde.header(), func(v any) ([]byte, error) { + return json.MarshalIndent(v, "", " ") + }) + if err != nil { + t.Errorf("#%d Encode[ID=100]: got err? %v, exp err? %v", i, gotErr, test.expErr) + continue + } + if i := bytes.IndexByte(bIndented, '{'); !bytes.Equal(bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) { + t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[:i], []byte{0, 0, 0, 0, 100, 0}) + } else if expIndented := extractIndentedJSON(b); !bytes.Equal(bIndented[i:], expIndented) { + t.Errorf("#%d got Encode[ID=100](%v) != exp(%v)", i, bIndented[i:], expIndented) + } + v, err := serde.DecodeNew(b) if err != nil { t.Errorf("#%d DecodeNew: got unexpected err %v", i, err) @@ -126,6 +138,7 @@ func TestSerde(t *testing.T) { } if !reflect.DeepEqual(v, exp) { t.Errorf("#%d round trip: got %v != exp %v", i, v, exp) + continue } } @@ -141,6 +154,20 @@ func TestSerde(t *testing.T) { if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 99}); err != ErrNotRegistered { t.Errorf("got %v != exp ErrNotRegistered", err) } + if _, err := serde.DecodeNew([]byte{0, 0, 0, 0, 100, 0}); err != ErrNotRegistered { + // schema is registered but type is unknown + t.Errorf("got %v != exp ErrNotRegistered", err) + } +} + +func extractIndentedJSON(in []byte) []byte { + i := bytes.IndexByte(in, '{') // skip header + var out bytes.Buffer + err := json.Indent(&out, in[i:], "", " ") + if err != nil { + panic(err) + } + return out.Bytes() } func TestConfluentHeader(t *testing.T) {