Skip to content

Commit

Permalink
{channel,server}_test: add tests for message limits.
Browse files Browse the repository at this point in the history
Adjust unit test to accomodate for altered internal interfaces.
Add unit tests to exercise the new message size limit options.

Signed-off-by: Krisztian Litkey <[email protected]>
  • Loading branch information
klihub committed Aug 27, 2024
1 parent 0753662 commit 855d1d7
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 35 deletions.
6 changes: 3 additions & 3 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
func TestReadWriteMessage(t *testing.T) {
var (
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
ch = newChannel(w, 0)
rch = newChannel(r, 0)
messages = [][]byte{
[]byte("hello"),
[]byte("this is a test"),
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) {
func TestMessageOversize(t *testing.T) {
var (
w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r)
wch, rch = newChannel(w, 0), newChannel(r, 0)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)
Expand Down
164 changes: 132 additions & 32 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package ttrpc
import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (*
}

// testingServer is what would be implemented by the user of this package.
type testingServer struct{}
type testingServer struct {
echoOnce bool
}

func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) {
tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)}
tp := &internal.TestPayload{}
if s.echoOnce {
tp.Foo = req.Foo
} else {
tp.Foo = strings.Repeat(req.Foo, 2)
}
if dl, ok := ctx.Deadline(); ok {
tp.Deadline = dl.UnixNano()
}
Expand Down Expand Up @@ -299,37 +307,122 @@ func TestServerClose(t *testing.T) {
}

func TestOversizeCall(t *testing.T) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()

registerTestingService(server, &testingServer{})
type testCase struct {
name string
echoOnce bool
clientLimit int
serverLimit int
requestSize int
shouldFail bool
}

runTest := func(t *testing.T, tc *testCase) {
var (
ctx = context.Background()
server = mustServer(t)(NewServer(WithServerWireMessageLimit(tc.serverLimit)))
addr, listener = newTestListener(t)
errs = make(chan error, 1)
client, cleanup = newTestClient(t, addr, WithClientWireMessageLimit(tc.clientLimit))
)
defer cleanup()
defer listener.Close()
go func() {
errs <- server.Serve(ctx, listener)
}()

registerTestingService(server, &testingServer{echoOnce: tc.echoOnce})

req := &internal.TestPayload{
Foo: strings.Repeat("a", tc.requestSize),
}
rsp := &internal.TestPayload{}

err := client.Call(ctx, serviceName, "Test", req, rsp)
if tc.shouldFail {
if err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
}
} else {
if err != nil {
t.Fatalf("expected success, got error %v", err)
}
}

tp := &internal.TestPayload{
Foo: strings.Repeat("a", 1+messageLengthMax),
}
if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil {
t.Fatalf("expected error from oversized message")
} else if status, ok := status.FromError(err); !ok {
t.Fatalf("expected status present in error: %v", err)
} else if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted)
if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)
}
}

if err := server.Shutdown(ctx); err != nil {
t.Fatal(err)
}
if err := <-errs; err != ErrServerClosed {
t.Fatal(err)
// in principle min. marshalled Request{} + messageheaderLength == 29 would be enough
overhead := 32

for _, tc := range []*testCase{
{
name: "default limits, fitting request and response",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: messageLengthMax - overhead,
shouldFail: false,
},
{
name: "default limits, oversized request",
echoOnce: true,
clientLimit: 0,
serverLimit: 0,
requestSize: messageLengthMax,
shouldFail: true,
},
{
name: "default limits, oversized response",
clientLimit: 0,
serverLimit: 0,
requestSize: messageLengthMax / 2,
shouldFail: true,
},
{
name: "8K limits, fitting 4K request and response",
echoOnce: true,
clientLimit: 8 * 1024,
serverLimit: 8 * 1024,
requestSize: 4 * 1024,
shouldFail: false,
},
{
name: "8K client limit, 4K server limit, fitting cc. 4K request and response",
echoOnce: true,
clientLimit: 4 * 1024,
serverLimit: 4 * 1024,
requestSize: 4*1024 - overhead,
shouldFail: false,
},
{
name: "8K client limit, 4K server limit, non-fitting 4K response",
echoOnce: true,
clientLimit: 4 * 1024,
serverLimit: 4 * 1024,
requestSize: 4 * 1024,
shouldFail: true,
},
{
name: "too small limits, adjusted to minimum accepted limit",
echoOnce: true,
clientLimit: 4,
serverLimit: 4,
requestSize: 4*1024 - overhead,
shouldFail: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
runTest(t, tc)
})
}
}

Expand Down Expand Up @@ -551,13 +644,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func
}

func newTestListener(t testing.TB) (string, net.Listener) {
var prefix string
var (
name = t.Name()
prefix string
)

// Abstracts sockets are only available on Linux.
if runtime.GOOS == "linux" {
prefix = "\x00"
} else {
if split := strings.SplitN(name, "/", 2); len(split) == 2 {
name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1])))
}
}
addr := prefix + t.Name()
addr := prefix + name
listener, err := net.Listen("unix", addr)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 855d1d7

Please sign in to comment.