diff --git a/pkg/rpc/codec.go b/pkg/rpc/codec.go index cc215968e..38da7c50a 100644 --- a/pkg/rpc/codec.go +++ b/pkg/rpc/codec.go @@ -1,35 +1,62 @@ package rpc import ( - gogoproto "github.com/gogo/protobuf/proto" - "github.com/golang/protobuf/proto" //nolint:staticcheck "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" ) const name = "proto" -type codec struct{} +type gogoprotoMessage interface { + MarshalToSizedBuffer([]byte) (int, error) + Unmarshal([]byte) error + ProtoSize() int +} + +var pool = mem.DefaultBufferPool() + +type codec struct { + fallback encoding.CodecV2 +} -var _ encoding.Codec = codec{} +var _ encoding.CodecV2 = &codec{} func init() { - encoding.RegisterCodec(codec{}) + encoding.RegisterCodecV2(&codec{ + fallback: encoding.GetCodecV2(name), + }) } -func (codec) Marshal(v interface{}) ([]byte, error) { - if m, ok := v.(gogoproto.Marshaler); ok { - return m.Marshal() +func (c *codec) Marshal(v any) (mem.BufferSlice, error) { + if m, ok := v.(gogoprotoMessage); ok { + size := m.ProtoSize() + if mem.IsBelowBufferPoolingThreshold(size) { + buf := make([]byte, size) + if _, err := m.MarshalToSizedBuffer(buf[:size]); err != nil { + return nil, err + } + return mem.BufferSlice{mem.SliceBuffer(buf)}, nil + } + + buf := pool.Get(size) + if _, err := m.MarshalToSizedBuffer((*buf)[:size]); err != nil { + pool.Put(buf) + return nil, err + } + return mem.BufferSlice{mem.NewBuffer(buf, pool)}, nil } - return proto.Marshal(v.(proto.Message)) + return c.fallback.Marshal(v) } -func (codec) Unmarshal(data []byte, v interface{}) error { - if m, ok := v.(gogoproto.Unmarshaler); ok { - return m.Unmarshal(data) +func (c *codec) Unmarshal(data mem.BufferSlice, v any) error { + if m, ok := v.(gogoprotoMessage); ok { + buf := data.MaterializeToBuffer(pool) + defer buf.Free() + return m.Unmarshal(buf.ReadOnlyData()) } - return proto.Unmarshal(data, v.(proto.Message)) + return c.fallback.Unmarshal(data, v) } -func (codec) Name() string { +func (*codec) Name() string { return name }