Skip to content

Commit

Permalink
{channel,server}_test: add tests for message limits.
Browse files Browse the repository at this point in the history
Adjust unit test to accomodate for altered internal interfaces.
Add unit tests to exercise the new message size limit options.

Signed-off-by: Krisztian Litkey <[email protected]>
  • Loading branch information
klihub committed Sep 12, 2024
1 parent 53a5b5a commit 6e89eda
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 32 deletions.
6 changes: 3 additions & 3 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
ch = newChannel(w, 0)
rch = newChannel(r, 0)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
func TestMessageOversize(t *testing.T) {
var (
w, _ = net.Pipe()
wch = newChannel(w)
wch = newChannel(w, 0)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)
Expand Down
188 changes: 159 additions & 29 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ttrpc
import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
}

// testingServer is what would be implemented by the user of this package.
type testingServer struct{}
type testingServer struct {
echoOnce bool
}

func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
tp := &internal.TestPayload{}
if s.echoOnce {
tp.Foo = req.Foo
} else {
tp.Foo = strings.Repeat(req.Foo, 2)
}
if dl, ok := ctx.Deadline(); ok {
tp.Deadline = dl.UnixNano()
}
Expand Down Expand Up @@ -299,37 +307,152 @@ func TestServerClose(t *testing.T) {
}

func TestOversizeCall(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()
type testCase struct {
name string
echoOnce bool
clientLimit int
serverLimit int
requestSize int
clientFail bool
serverFail bool
}

overhead := getWireMessageOverhead(t)

runTest := func(t *testing.T, tc *testCase) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(WithServerWireMessageLimit(tc.serverLimit)))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr, WithClientWireMessageLimit(tc.clientLimit))
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()

registerTestingService(server, &testingServer{echoOnce: tc.echoOnce})

req := &internal.TestPayload{
Foo: strings.Repeat("a", tc.requestSize),
}
rsp := &internal.TestPayload{}

err := client.Call(ctx, serviceName, "Test", req, rsp)
if tc.clientFail {
if err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
}
} else if tc.serverFail {
if err == nil {
t.Fatalf("expected error from server-side oversized message")
}
} else {
if err != nil {
t.Fatalf("expected success, got error %v", err)
}
}

registerTestingService(server, &testingServer{})
if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)
}
}

tp := &internal.TestPayload{
Foo: strings.Repeat("a", 1+messageLengthMax),
for _, tc := range []*testCase{
{
name: "default limits, fitting request and response",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit - overhead,
},
{
name: "default limits, oversized request",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit,
clientFail: true,
},
{
name: "default limits, oversized response",
clientLimit: 0,
serverLimit: 0,
requestSize: DefaultMessageLengthLimit / 2,
serverFail: true,
},
{
name: "8K limits, fitting 4K request and response",
echoOnce: true,
clientLimit: 8 * 1024,
serverLimit: 8 * 1024,
requestSize: 4 * 1024,
},
{
name: "8K limits, fitting cc. 4K request and response",
echoOnce: true,
clientLimit: 4 * 1024,
serverLimit: 4 * 1024,
requestSize: 4*1024 - overhead,
},
{
name: "4K limits, non-fitting 4K response",
echoOnce: true,
clientLimit: 4*1024 + overhead,
serverLimit: 4 * 1024,
requestSize: 4 * 1024,
serverFail: true,
},
{
name: "too small limits, adjusted to minimum accepted limit",
echoOnce: true,
clientLimit: 4,
serverLimit: 4,
requestSize: 4*1024 - overhead,
},
{
name: "maximum allowed protocol limit",
echoOnce: true,
clientLimit: MaxMessageLengthLimit,
serverLimit: MaxMessageLengthLimit,
requestSize: MaxMessageLengthLimit - overhead,
},
} {
t.Run(tc.name, func(t *testing.T) {
runTest(t, tc)
})
}
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
}

func getWireMessageOverhead(t *testing.T) int {
emptyReq, err := codec{}.Marshal(&Request{
Service: serviceName,
Method: "Test",
})
if err != nil {
t.Fatalf("failed to marshal empty request: %v", err)
}

if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
emptyRsp, err := codec{}.Marshal(&Response{
Status: status.New(codes.OK, "").Proto(),
})
if err != nil {
t.Fatalf("failed to marshal empty response: %v", err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)

if reqLen, rspLen := len(emptyReq), len(emptyRsp); reqLen > rspLen {
return reqLen + messageHeaderLength
} else {
return rspLen + messageHeaderLength
}
}

Expand Down Expand Up @@ -551,13 +674,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
}

func newTestListener(t testing.TB) (string, net.Listener) {
var prefix string
var (
name = t.Name()
prefix string
)

// Abstracts sockets are only available on Linux.
if runtime.GOOS == "linux" {
prefix = "\x00"
} else {
if split := strings.SplitN(name, "/", 2); len(split) == 2 {
name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1])))
}
}
addr := prefix + t.Name()
addr := prefix + name
listener, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 6e89eda

Please sign in to comment.