From 855d1d767740a0e8b40383c8c46ad0f2f40f7304 Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Mon, 26 Aug 2024 17:20:35 +0300 Subject: [PATCH] {channel,server}_test: add tests for message limits. Adjust unit test to accomodate for altered internal interfaces. Add unit tests to exercise the new message size limit options. Signed-off-by: Krisztian Litkey --- channel_test.go | 6 +- server_test.go | 164 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 135 insertions(+), 35 deletions(-) diff --git a/channel_test.go b/channel_test.go index de8b66d38..ec5ebe0ad 100644 --- a/channel_test.go +++ b/channel_test.go @@ -31,8 +31,8 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( w, r = net.Pipe() - ch = newChannel(w) - rch = newChannel(r) + ch = newChannel(w, 0) + rch = newChannel(r, 0) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( w, r = net.Pipe() - wch, rch = newChannel(w), newChannel(r) + wch, rch = newChannel(w, 0), newChannel(r, 0) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) errs = make(chan error, 1) ) diff --git a/server_test.go b/server_test.go index cf34986d6..ed3c6a2c5 100644 --- a/server_test.go +++ b/server_test.go @@ -19,6 +19,7 @@ package ttrpc import ( "bytes" "context" + "crypto/md5" "errors" "fmt" "net" @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (* } // testingServer is what would be implemented by the user of this package. -type testingServer struct{} +type testingServer struct { + echoOnce bool +} func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) { - tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)} + tp := &internal.TestPayload{} + if s.echoOnce { + tp.Foo = req.Foo + } else { + tp.Foo = strings.Repeat(req.Foo, 2) + } if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } @@ -299,37 +307,122 @@ func TestServerClose(t *testing.T) { } func TestOversizeCall(t *testing.T) { - var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - ) - defer cleanup() - defer listener.Close() - go func() { - errs <- server.Serve(ctx, listener) - }() - - registerTestingService(server, &testingServer{}) + type testCase struct { + name string + echoOnce bool + clientLimit int + serverLimit int + requestSize int + shouldFail bool + } + + runTest := func(t *testing.T, tc *testCase) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(WithServerWireMessageLimit(tc.serverLimit))) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr, WithClientWireMessageLimit(tc.clientLimit)) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{echoOnce: tc.echoOnce}) + + req := &internal.TestPayload{ + Foo: strings.Repeat("a", tc.requestSize), + } + rsp := &internal.TestPayload{} + + err := client.Call(ctx, serviceName, "Test", req, rsp) + if tc.shouldFail { + if err == nil { + t.Fatalf("expected error from oversized message") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + } else { + if err != nil { + t.Fatalf("expected success, got error %v", err) + } + } - tp := &internal.TestPayload{ - Foo: strings.Repeat("a", 1+messageLengthMax), - } - if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { - t.Fatalf("expected error from oversized message") - } else if status, ok := status.FromError(err); !ok { - t.Fatalf("expected status present in error: %v", err) - } else if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } } - if err := server.Shutdown(ctx); err != nil { - t.Fatal(err) - } - if err := <-errs; err != ErrServerClosed { - t.Fatal(err) + // in principle min. marshalled Request{} + messageheaderLength == 29 would be enough + overhead := 32 + + for _, tc := range []*testCase{ + { + name: "default limits, fitting request and response", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: messageLengthMax - overhead, + shouldFail: false, + }, + { + name: "default limits, oversized request", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: messageLengthMax, + shouldFail: true, + }, + { + name: "default limits, oversized response", + clientLimit: 0, + serverLimit: 0, + requestSize: messageLengthMax / 2, + shouldFail: true, + }, + { + name: "8K limits, fitting 4K request and response", + echoOnce: true, + clientLimit: 8 * 1024, + serverLimit: 8 * 1024, + requestSize: 4 * 1024, + shouldFail: false, + }, + { + name: "8K client limit, 4K server limit, fitting cc. 4K request and response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4*1024 - overhead, + shouldFail: false, + }, + { + name: "8K client limit, 4K server limit, non-fitting 4K response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + shouldFail: true, + }, + { + name: "too small limits, adjusted to minimum accepted limit", + echoOnce: true, + clientLimit: 4, + serverLimit: 4, + requestSize: 4*1024 - overhead, + shouldFail: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc) + }) } } @@ -551,13 +644,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func } func newTestListener(t testing.TB) (string, net.Listener) { - var prefix string + var ( + name = t.Name() + prefix string + ) // Abstracts sockets are only available on Linux. if runtime.GOOS == "linux" { prefix = "\x00" + } else { + if split := strings.SplitN(name, "/", 2); len(split) == 2 { + name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1]))) + } } - addr := prefix + t.Name() + addr := prefix + name listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err)