client, server: implement configurable wire message size limits. #172

Draft: wants to merge 2 commits into main
61 changes: 43 additions & 18 deletions channel.go
@@ -23,14 +23,13 @@ import (
"io"
"net"
"sync"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
messageHeaderLength = 10
messageLengthMax = 4 << 20
messageHeaderLength = 10
MinMessageLengthLimit = 4 << 10
MaxMessageLengthLimit = 4 << 22
DefaultMessageLengthLimit = 4 << 20
)

type messageType uint8
@@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
var buffers sync.Pool

type channel struct {
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
conn net.Conn
bw *bufio.Writer
br *bufio.Reader
hrbuf [messageHeaderLength]byte // avoid alloc when reading header
hwbuf [messageHeaderLength]byte
maxMsgLen int
}

func newChannel(conn net.Conn) *channel {
func newChannel(conn net.Conn, maxMsgLen int) *channel {
if maxMsgLen == 0 {
maxMsgLen = DefaultMessageLengthLimit
}
return &channel{
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
conn: conn,
bw: bufio.NewWriter(conn),
br: bufio.NewReader(conn),
maxMsgLen: maxMsgLen,
}
}

@@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
return messageHeader{}, nil, err
}

if mh.Length > uint32(messageLengthMax) {
if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) {
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
}

return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen)
}

var p []byte
@@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
if len(p) > messageLengthMax {
return OversizedMessageError(len(p))
if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 {
if len(p) > maxMsgLen {
return OversizedMessageError(len(p), maxMsgLen)
}
}

if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
@@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
func (ch *channel) putmbuf(p []byte) {
buffers.Put(&p)
}

func (ch *channel) maxMsgLimit(recv bool) int {
if ch.maxMsgLen == 0 && recv {
return DefaultMessageLengthLimit
}
return ch.maxMsgLen
}

func clampWireMessageLimit(maxMsgLen int) int {
switch {
case maxMsgLen == 0:
return 0
case maxMsgLen < MinMessageLengthLimit:
return MinMessageLengthLimit
case maxMsgLen > MaxMessageLengthLimit:
return MaxMessageLengthLimit
}
return maxMsgLen
}
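
Taken together, clampWireMessageLimit and maxMsgLimit define the effective policy: a configured value is clamped into [MinMessageLengthLimit, MaxMessageLengthLimit], a value of 0 means "use the default" (newChannel substitutes DefaultMessageLengthLimit), and the receive path falls back to the default even if the stored limit were ever zero, while the send path would then be unlimited. A minimal, self-contained sketch of the clamping behavior; the constants are copied from the diff above and main is illustration only:

package main

import "fmt"

const (
	MinMessageLengthLimit     = 4 << 10 // 4 KiB
	MaxMessageLengthLimit     = 4 << 22 // 16 MiB
	DefaultMessageLengthLimit = 4 << 20 // 4 MiB
)

// clampWireMessageLimit mirrors the helper above: 0 passes through,
// anything else is clamped into the allowed range.
func clampWireMessageLimit(maxMsgLen int) int {
	switch {
	case maxMsgLen == 0:
		return 0
	case maxMsgLen < MinMessageLengthLimit:
		return MinMessageLengthLimit
	case maxMsgLen > MaxMessageLengthLimit:
		return MaxMessageLengthLimit
	}
	return maxMsgLen
}

func main() {
	for _, n := range []int{0, 1 << 10, 8 << 20, 1 << 30} {
		fmt.Printf("clampWireMessageLimit(%d) = %d\n", n, clampWireMessageLimit(n))
	}
	// Output:
	// clampWireMessageLimit(0) = 0
	// clampWireMessageLimit(1024) = 4096
	// clampWireMessageLimit(8388608) = 8388608
	// clampWireMessageLimit(1073741824) = 16777216
}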
6 changes: 3 additions & 3 deletions channel_test.go
@@ -31,8 +31,8 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
ch = newChannel(w, 0)
rch = newChannel(r, 0)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
@@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
func TestMessageOversize(t *testing.T) {
var (
w, _ = net.Pipe()
wch = newChannel(w)
wch = newChannel(w, 0)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)
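
Not part of this PR, but a natural follow-up assertion for TestMessageOversize given the new error fields. This is a sketch that assumes the test still runs in package ttrpc (so DefaultMessageLengthLimit is reachable), that the receiving goroutine forwards the send error on errs, and that errors is imported:

err := <-errs
var oerr *OversizedMessageErr
if !errors.As(err, &oerr) {
	t.Fatalf("expected *OversizedMessageErr, got %v", err)
}
if oerr.MaximumLength() != DefaultMessageLengthLimit {
	t.Fatalf("expected limit %d, got %d", DefaultMessageLengthLimit, oerr.MaximumLength())
}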
19 changes: 14 additions & 5 deletions client.go
@@ -35,9 +35,10 @@ import (

// Client for a ttrpc server
type Client struct {
codec codec
conn net.Conn
channel *channel
codec codec
conn net.Conn
channel *channel
maxMsgLen int

streamLock sync.RWMutex
streams map[streamID]*stream
@@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker
}
}

// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client.
func WithClientWireMessageLimit(maxMsgLen int) ClientOpts {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *Client) {
c.maxMsgLen = maxMsgLen
}
}

// NewClient creates a new ttrpc client using the given connection
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
channel := newChannel(conn)
c := &Client{
codec: codec{},
conn: conn,
channel: channel,
streams: make(map[streamID]*stream),
nextStreamID: 1,
closed: cancel,
@@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
o(c)
}

c.channel = newChannel(conn, c.maxMsgLen)

if c.interceptor == nil {
c.interceptor = defaultClientInterceptor
}
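
A usage sketch for the new client option; the import path and socket address are assumptions, and error handling is abbreviated. Values outside the allowed range are safe to pass, since the option clamps them:

package main

import (
	"log"
	"net"

	"github.com/containerd/ttrpc"
)

func main() {
	conn, err := net.Dial("unix", "/run/example.sock") // placeholder address
	if err != nil {
		log.Fatal(err)
	}
	// 8 << 20 (8 MiB) is an arbitrary example; 0 selects the default limit.
	client := ttrpc.NewClient(conn, ttrpc.WithClientWireMessageLimit(8<<20))
	defer client.Close()
}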
10 changes: 10 additions & 0 deletions config.go
@@ -24,6 +24,7 @@ import (
type serverConfig struct {
handshaker Handshaker
interceptor UnaryServerInterceptor
maxMsgLen int
}

// ServerOpt for configuring a ttrpc server
@@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep
chainUnaryServerInterceptors(info, method, interceptors[1:]))
}
}

// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server.
func WithServerWireMessageLimit(maxMsgLen int) ServerOpt {
maxMsgLen = clampWireMessageLimit(maxMsgLen)
return func(c *serverConfig) error {
c.maxMsgLen = maxMsgLen
return nil
}
}
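
The matching server-side sketch (listener setup and service registration elided; the value is illustrative and is clamped just like on the client):

server, err := ttrpc.NewServer(
	ttrpc.WithServerWireMessageLimit(8 << 20),
)
if err != nil {
	log.Fatal(err)
}
defer server.Close()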
50 changes: 45 additions & 5 deletions errors.go
@@ -18,6 +18,7 @@ package ttrpc

import (
"errors"
"fmt"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -43,20 +44,59 @@ var (
// length.
type OversizedMessageErr struct {
messageLength int
maxLength int
err error
}

var (
oversizedMsgFmt = "message length %d exceeds maximum message size of %d"
oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt))
)

// OversizedMessageError returns an OversizedMessageErr error for the given message
// length if it exceeds the allowed maximum. Otherwise a nil error is returned.
func OversizedMessageError(messageLength int) error {
if messageLength <= messageLengthMax {
func OversizedMessageError(messageLength, maxLength int) error {
if messageLength <= maxLength {
return nil
}

return &OversizedMessageErr{
messageLength: messageLength,
err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax),
maxLength: maxLength,
err: OversizedMessageStatus(messageLength, maxLength).Err(),
}
}

// OversizedMessageStatus returns a Status for an oversized message error.
func OversizedMessageStatus(messageLength, maxLength int) *status.Status {
return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength)
}

// OversizedMessageFromError reconstructs an OversizedMessageErr from a gRPC status error.
func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) {
var (
messageLength int
maxLength int
)

st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}

// TODO(klihub): might be too ugly to recover an error this way... An
// alternative would be to define our custom status detail proto type,
// then use status.WithDetails() and status.Details().

n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength)
if n != 2 {
n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength)
}
if n != 2 {
return nil, false
}

return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true
}

// Error returns the message of the corresponding grpc Status for the error.
@@ -75,6 +115,6 @@ func (e *OversizedMessageErr) RejectedLength() int {
}

// MaximumLength retrieves the maximum allowed message length that triggered the error.
func (*OversizedMessageErr) MaximumLength() int {
return messageLengthMax
func (e *OversizedMessageErr) MaximumLength() int {
return e.maxLength
}
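
From a caller's perspective the two failure modes surface differently: a local send-side rejection is a *OversizedMessageErr directly, while a server-side rejection arrives as a ResourceExhausted status that OversizedMessageFromError can turn back into one. A handling sketch; the service and method names, req, and resp are placeholders:

if err := client.Call(ctx, "example.Service", "Method", req, resp); err != nil {
	var oerr *ttrpc.OversizedMessageErr
	if errors.As(err, &oerr) { // rejected locally before hitting the wire
		log.Printf("request too large: %d > %d", oerr.RejectedLength(), oerr.MaximumLength())
	} else if oerr, ok := ttrpc.OversizedMessageFromError(err); ok { // rejected by the peer
		log.Printf("peer rejected message: %d > %d", oerr.RejectedLength(), oerr.MaximumLength())
	}
}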
21 changes: 20 additions & 1 deletion server.go
@@ -339,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) {
)

var (
ch = newChannel(c.conn)
ch = newChannel(c.conn, c.server.config.maxMsgLen)
ctx, cancel = context.WithCancel(sctx)
state connState = connStateIdle
responses = make(chan response)
@@ -373,6 +373,14 @@
}
}

isResourceExhaustedError := func(err error) (*status.Status, bool) {
st, ok := status.FromError(err)
if !ok || st.Code() != codes.ResourceExhausted {
return nil, false
}
return st, true
}

go func(recvErr chan error) {
defer close(recvErr)
for {
@@ -525,6 +533,17 @@
}

if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
if st, ok := isResourceExhaustedError(err); ok {
p, err = c.server.codec.Marshal(&Response{
Status: st.Proto(),
})
if err != nil {
log.G(ctx).WithError(err).Error("failed marshaling error response")
return
}
ch.send(response.id, messageTypeResponse, 0, p)
return
}
log.G(ctx).WithError(err).Error("failed sending message on channel")
return
}
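
Design note on the hunk above: when a marshaled response itself exceeds the limit, the server no longer just logs the send failure and drops the connection. It re-marshals a Response carrying only the ResourceExhausted status proto and sends that instead, so the client gets a structured, recoverable error (see OversizedMessageFromError in errors.go) rather than an abruptly terminated connection.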