From dc692fa98fa36b65ce24e37dbe10222ce117fef0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Vicent=20Mart=C3=AD?= <42793+vmg@users.noreply.github.com>
Date: Thu, 17 Oct 2024 15:21:27 +0200
Subject: [PATCH] grpc: upgrade to 1.66.2 and use Codec v2 (#16790)

Signed-off-by: Vicent Marti <vmg@strn.cat>
Signed-off-by: Matt Lord <mattalord@gmail.com>
Co-authored-by: Matt Lord <mattalord@gmail.com>
---
 go.mod                      |  2 +-
 go/vt/servenv/grpc_codec.go | 70 ++++++++++++++++++++++---------------
 2 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/go.mod b/go.mod
index adb74474f6c..8e7a90f89c8 100644
--- a/go.mod
+++ b/go.mod
@@ -18,7 +18,7 @@ require (
 	github.com/fsnotify/fsnotify v1.7.0
 	github.com/go-sql-driver/mysql v1.7.1
 	github.com/golang/glog v1.2.2
-	github.com/golang/protobuf v1.5.4
+	github.com/golang/protobuf v1.5.4 // indirect
 	github.com/golang/snappy v0.0.4
 	github.com/google/go-cmp v0.6.0
 	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
diff --git a/go/vt/servenv/grpc_codec.go b/go/vt/servenv/grpc_codec.go
index 7d2b6364d3b..35441feb261 100644
--- a/go/vt/servenv/grpc_codec.go
+++ b/go/vt/servenv/grpc_codec.go
@@ -17,52 +17,64 @@ limitations under the License.
 package servenv
 
 import (
-	"fmt"
-
-	// use the original golang/protobuf package we can continue serializing
-	// messages from our dependencies, particularly from etcd
-	"github.com/golang/protobuf/proto" //nolint
-
 	"google.golang.org/grpc/encoding"
+	"google.golang.org/grpc/mem"
+
+	// Guarantee that the built-in proto is called registered before this one
+	// so that it can be replaced.
 	_ "google.golang.org/grpc/encoding/proto" // nolint:revive
 )
 
 // Name is the name registered for the proto compressor.
 const Name = "proto"
 
-type vtprotoCodec struct{}
-
 type vtprotoMessage interface {
-	MarshalVT() ([]byte, error)
+	MarshalToSizedBufferVT(data []byte) (int, error)
 	UnmarshalVT([]byte) error
+	SizeVT() int
 }
 
-func (vtprotoCodec) Marshal(v any) ([]byte, error) {
-	switch v := v.(type) {
-	case vtprotoMessage:
-		return v.MarshalVT()
-	case proto.Message:
-		return proto.Marshal(v)
-	default:
-		return nil, fmt.Errorf("failed to marshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v)
-	}
+type Codec struct {
+	fallback encoding.CodecV2
 }
 
-func (vtprotoCodec) Unmarshal(data []byte, v any) error {
-	switch v := v.(type) {
-	case vtprotoMessage:
-		return v.UnmarshalVT(data)
-	case proto.Message:
-		return proto.Unmarshal(data, v)
-	default:
-		return fmt.Errorf("failed to unmarshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v)
+func (Codec) Name() string { return Name }
+
+var defaultBufferPool = mem.DefaultBufferPool()
+
+func (c *Codec) Marshal(v any) (mem.BufferSlice, error) {
+	if m, ok := v.(vtprotoMessage); ok {
+		size := m.SizeVT()
+		if mem.IsBelowBufferPoolingThreshold(size) {
+			buf := make([]byte, size)
+			if _, err := m.MarshalToSizedBufferVT(buf[:size]); err != nil {
+				return nil, err
+			}
+			return mem.BufferSlice{mem.SliceBuffer(buf)}, nil
+		}
+		buf := defaultBufferPool.Get(size)
+		if _, err := m.MarshalToSizedBufferVT((*buf)[:size]); err != nil {
+			defaultBufferPool.Put(buf)
+			return nil, err
+		}
+		return mem.BufferSlice{mem.NewBuffer(buf, defaultBufferPool)}, nil
 	}
+
+	return c.fallback.Marshal(v)
 }
 
-func (vtprotoCodec) Name() string {
-	return Name
+func (c *Codec) Unmarshal(data mem.BufferSlice, v any) error {
+	if m, ok := v.(vtprotoMessage); ok {
+		buf := data.MaterializeToBuffer(defaultBufferPool)
+		defer buf.Free()
+		return m.UnmarshalVT(buf.ReadOnlyData())
+	}
+
+	return c.fallback.Unmarshal(data, v)
 }
 
 func init() {
-	encoding.RegisterCodec(vtprotoCodec{})
+	encoding.RegisterCodecV2(&Codec{
+		fallback: encoding.GetCodecV2("proto"),
+	})
 }