diff --git a/client_test.go b/client_test.go index 80f59d5..15e7e8c 100644 --- a/client_test.go +++ b/client_test.go @@ -28,11 +28,9 @@ func TestClient(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { - ApiKey: "foobar", - Calls: []dispatch.Call{call}, - }, + recorder.Assert(t, dispatchtest.DispatchRequest{ + ApiKey: "foobar", + Calls: []dispatch.Call{call}, }) } @@ -58,11 +56,9 @@ func TestClientEnvConfig(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { - ApiKey: "foobar", - Calls: []dispatch.Call{call}, - }, + recorder.Assert(t, dispatchtest.DispatchRequest{ + ApiKey: "foobar", + Calls: []dispatch.Call{call}, }) } @@ -107,16 +103,15 @@ func TestClientBatch(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { + recorder.Assert(t, + dispatchtest.DispatchRequest{ ApiKey: "foobar", Calls: []dispatch.Call{call1, call2}, }, - { + dispatchtest.DispatchRequest{ ApiKey: "foobar", Calls: []dispatch.Call{call3, call4}, - }, - }) + }) } func TestClientNoAPIKey(t *testing.T) { diff --git a/dispatch_test.go b/dispatch_test.go index db35dd7..f180302 100644 --- a/dispatch_test.go +++ b/dispatch_test.go @@ -115,11 +115,9 @@ func TestDispatchCall(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { - ApiKey: "foobar", - Calls: []dispatch.Call{wantCall}, - }, + recorder.Assert(t, dispatchtest.DispatchRequest{ + ApiKey: "foobar", + Calls: []dispatch.Call{wantCall}, }) } @@ -151,11 +149,9 @@ func TestDispatchCallEnvConfig(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { - ApiKey: "foobar", - Calls: []dispatch.Call{wantCall}, - }, + recorder.Assert(t, dispatchtest.DispatchRequest{ + ApiKey: "foobar", + Calls: []dispatch.Call{wantCall}, }) } @@ -201,11 +197,9 @@ func TestDispatchCallsBatch(t *testing.T) { t.Fatal(err) } - dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{ - { - ApiKey: "foobar", - Calls: []dispatch.Call{call1, call2}, - }, + recorder.Assert(t, dispatchtest.DispatchRequest{ + ApiKey: "foobar", + Calls: []dispatch.Call{call1, call2}, }) } diff --git a/dispatchtest/assert.go b/dispatchtest/assert.go deleted file mode 100644 index a49666c..0000000 --- a/dispatchtest/assert.go +++ /dev/null @@ -1,42 +0,0 @@ -package dispatchtest - -import ( - "testing" - - "github.com/dispatchrun/dispatch-go" -) - -func AssertCalls(t *testing.T, got, want []dispatch.Call) { - t.Helper() - - if len(got) != len(want) { - t.Fatalf("unexpected number of calls: got %v, want %v", len(got), len(want)) - } - for i, call := range got { - if !call.Equal(want[i]) { - t.Errorf("unexpected call %d: got %v, want %v", i, call, want[i]) - } - } -} - -func AssertCall(t *testing.T, got, want dispatch.Call) { - t.Helper() - - if !got.Equal(want) { - t.Errorf("unexpected call: got %v, want %v", got, want) - } -} - -func AssertDispatchRequests(t *testing.T, got, want []DispatchRequest) { - t.Helper() - - if len(got) != len(want) { - t.Fatalf("unexpected number of requests: got %v, want %v", len(got), len(want)) - } - for i, req := range got { - if req.ApiKey != want[i].ApiKey { - t.Errorf("unexpected API key on request %d: got %v, want %v", i, req.ApiKey, want[i].ApiKey) - } - AssertCalls(t, req.Calls, want[i].Calls) - } -} diff --git a/dispatchtest/server.go b/dispatchtest/server.go index 7590155..09b4896 100644 --- a/dispatchtest/server.go +++ b/dispatchtest/server.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strconv" "strings" + "testing" "buf.build/gen/go/stealthrocket/dispatch-proto/connectrpc/go/dispatch/sdk/v1/sdkv1connect" sdkv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/v1" @@ -82,7 +83,7 @@ func wrapCall(c *sdkv1.Call) (dispatch.Call, error) { // CallRecorder is a DispatchServerHandler that captures requests to the Dispatch API. type CallRecorder struct { - Requests []DispatchRequest + requests []DispatchRequest calls int } @@ -96,7 +97,7 @@ func (r *CallRecorder) Handle(ctx context.Context, apiKey string, calls []dispat base := r.calls r.calls += len(calls) - r.Requests = append(r.Requests, DispatchRequest{ + r.requests = append(r.requests, DispatchRequest{ ApiKey: apiKey, Calls: calls, }) @@ -107,3 +108,26 @@ func (r *CallRecorder) Handle(ctx context.Context, apiKey string, calls []dispat } return ids, nil } + +func (r *CallRecorder) Assert(t *testing.T, want ...DispatchRequest) { + t.Helper() + + got := r.requests + if len(got) != len(want) { + t.Fatalf("unexpected number of requests: got %v, want %v", len(got), len(want)) + } + for i, req := range got { + if req.ApiKey != want[i].ApiKey { + t.Errorf("unexpected API key on request %d: got %v, want %v", i, req.ApiKey, want[i].ApiKey) + } + if len(req.Calls) != len(want[i].Calls) { + t.Errorf("unexpected number of calls in request %d: got %v, want %v", i, len(req.Calls), len(want[i].Calls)) + } else { + for j, call := range req.Calls { + if !call.Equal(want[i].Calls[j]) { + t.Errorf("unexpected request %d call %d: got %v, want %v", i, j, call, want[i].Calls[j]) + } + } + } + } +}