diff --git a/go/vt/servenv/grpc_codec.go b/go/vt/servenv/grpc_codec.go index 4376783de20..7d2b6364d3b 100644 --- a/go/vt/servenv/grpc_codec.go +++ b/go/vt/servenv/grpc_codec.go @@ -38,29 +38,25 @@ type vtprotoMessage interface { } func (vtprotoCodec) Marshal(v any) ([]byte, error) { - vt, ok := v.(vtprotoMessage) - if ok { - return vt.MarshalVT() + switch v := v.(type) { + case vtprotoMessage: + return v.MarshalVT() + case proto.Message: + return proto.Marshal(v) + default: + return nil, fmt.Errorf("failed to marshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v) } - - vv, ok := v.(proto.Message) - if !ok { - return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) - } - return proto.Marshal(vv) } func (vtprotoCodec) Unmarshal(data []byte, v any) error { - vt, ok := v.(vtprotoMessage) - if ok { - return vt.UnmarshalVT(data) - } - - vv, ok := v.(proto.Message) - if !ok { - return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + switch v := v.(type) { + case vtprotoMessage: + return v.UnmarshalVT(data) + case proto.Message: + return proto.Unmarshal(data, v) + default: + return fmt.Errorf("failed to unmarshal, message is %T, must satisfy the vtprotoMessage interface or want proto.Message", v) } - return proto.Unmarshal(data, vv) } func (vtprotoCodec) Name() string {