diff --git a/go.mod b/go.mod index adb74474f6c..8e7a90f89c8 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/go-sql-driver/mysql v1.7.1 github.com/golang/glog v1.2.2 - github.com/golang/protobuf v1.5.4 + github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 github.com/google/go-cmp v0.6.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/go/vt/servenv/grpc_codec.go b/go/vt/servenv/grpc_codec.go index 7d2b6364d3b..35441feb261 100644 --- a/go/vt/servenv/grpc_codec.go +++ b/go/vt/servenv/grpc_codec.go @@ -17,52 +17,64 @@ limitations under the License. package servenv import ( - "fmt" - - // use the original golang/protobuf package we can continue serializing - // messages from our dependencies, particularly from etcd - "github.com/golang/protobuf/proto" //nolint - "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" + + // Guarantee that the built-in proto is called registered before this one + // so that it can be replaced. _ "google.golang.org/grpc/encoding/proto" // nolint:revive ) // Name is the name registered for the proto compressor. const Name = "proto" -type vtprotoCodec struct{} - type vtprotoMessage interface { - MarshalVT() ([]byte, error) + MarshalToSizedBufferVT(data []byte) (int, error) UnmarshalVT([]byte) error + SizeVT() int } -func (vtprotoCodec) Marshal(v any) ([]byte, error) { - switch v := v.(type) { - case vtprotoMessage: - return v.MarshalVT() - case proto.Message: - return proto.Marshal(v) - default: - return nil, fmt.Errorf("failed to marshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v) - } +type Codec struct { + fallback encoding.CodecV2 } -func (vtprotoCodec) Unmarshal(data []byte, v any) error { - switch v := v.(type) { - case vtprotoMessage: - return v.UnmarshalVT(data) - case proto.Message: - return proto.Unmarshal(data, v) - default: - return fmt.Errorf("failed to unmarshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v) +func (Codec) Name() string { return Name } + +var defaultBufferPool = mem.DefaultBufferPool() + +func (c *Codec) Marshal(v any) (mem.BufferSlice, error) { + if m, ok := v.(vtprotoMessage); ok { + size := m.SizeVT() + if mem.IsBelowBufferPoolingThreshold(size) { + buf := make([]byte, size) + if _, err := m.MarshalToSizedBufferVT(buf[:size]); err != nil { + return nil, err + } + return mem.BufferSlice{mem.SliceBuffer(buf)}, nil + } + buf := defaultBufferPool.Get(size) + if _, err := m.MarshalToSizedBufferVT((*buf)[:size]); err != nil { + defaultBufferPool.Put(buf) + return nil, err + } + return mem.BufferSlice{mem.NewBuffer(buf, defaultBufferPool)}, nil } + + return c.fallback.Marshal(v) } -func (vtprotoCodec) Name() string { - return Name +func (c *Codec) Unmarshal(data mem.BufferSlice, v any) error { + if m, ok := v.(vtprotoMessage); ok { + buf := data.MaterializeToBuffer(defaultBufferPool) + defer buf.Free() + return m.UnmarshalVT(buf.ReadOnlyData()) + } + + return c.fallback.Unmarshal(data, v) } func init() { - encoding.RegisterCodec(vtprotoCodec{}) + encoding.RegisterCodecV2(&Codec{ + fallback: encoding.GetCodecV2("proto"), + }) }