diff --git a/internal/config/env_test.go b/internal/config/env_test.go index 0134a0e3..601023f1 100644 --- a/internal/config/env_test.go +++ b/internal/config/env_test.go @@ -156,10 +156,13 @@ func TestFileName(t *testing.T) { } } +// keyValueArgs represents a key-value pair for environment variable setup type keyValueArgs struct { key string value string } + +// args holds the setup for environment variables and expected key-value pairs for assertions. type args struct { setEnv []keyValueArgs expectedKeyValues []keyValueArgs @@ -174,7 +177,7 @@ func TestLoadEnv(t *testing.T) { tests := getTestCases(username, host, dbname, password, port) for _, tt := range tests { setEnvironmentVariables(tt.args) - + defer clearEnvironmentVariables(tt.args) t.Run(tt.name, func(t *testing.T) { testLoadEnv(t, tt) }) @@ -348,3 +351,9 @@ func testLoadEnv(t *testing.T, tt struct { } } } + +func clearEnvironmentVariables(args args) { + for _, env := range args.setEnv { + os.Unsetenv(env.key) + } +} diff --git a/internal/middleware/auth/auth_test.go b/internal/middleware/auth/auth_test.go index 0319c25b..b1d9432d 100644 --- a/internal/middleware/auth/auth_test.go +++ b/internal/middleware/auth/auth_test.go @@ -41,16 +41,7 @@ func (s tokenParserMock) ParseToken(token string) (*jwt.Token, error) { var operationHandlerMock func(ctx context.Context) graphql2.ResponseHandler -func TestGraphQLMiddleware(t *testing.T) { - // Define test cases - cases := defineTestCases(t) - _, cleanup, _ := testutls.SetupMockDB(t) - defer cleanup() - // Run test cases - runTestCases(t, cases) -} - -func defineTestCases(t *testing.T) map[string]struct { +type testGraphQLMiddlewareType struct { wantStatus int header string signMethod string @@ -59,17 +50,36 @@ func defineTestCases(t *testing.T) map[string]struct { operationHandler func(ctx context.Context) graphql2.ResponseHandler tokenParser func(token string) (*jwt.Token, error) whiteListedQuery bool -} { - return map[string]struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ + init func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock +} + +func TestGraphQLMiddleware(t *testing.T) { + // Define test cases + cases := defineTestCases(t) + _, cleanup, _ := testutls.SetupMockDB(t) + defer cleanup() + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + mock := tt.init(t, tt.dbQueries) + // Determine request query + requestQuery := testutls.MockQuery + if tt.whiteListedQuery { + requestQuery = testutls.MockWhitelistedQuery + } + + // Make request + makeRequest(t, requestQuery, tt) + + // Ensure mock expectations are met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("mock expectations were not met: %v", err) + } + }) + } +} + +func defineTestCases(t *testing.T) map[string]testGraphQLMiddlewareType { + return map[string]testGraphQLMiddlewareType{ "SuccessCase": defineSuccessCase(t), "Success__WhitelistedQuery": defineSuccessWhitelistedQuery(), "Failure__NoAuthorizationToken": defineFailureNoAuthorizationToken(), @@ -79,26 +89,8 @@ func defineTestCases(t *testing.T) map[string]struct { } } -func defineSuccessCase(t *testing.T) struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineSuccessCase(t *testing.T) testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: false, header: "Bearer 123", wantStatus: http.StatusOK, @@ -120,29 +112,20 @@ func defineSuccessCase(t *testing.T) struct { ), }, }, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } -func defineSuccessWhitelistedQuery() struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineSuccessWhitelistedQuery() testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: true, header: "bearer 123", wantStatus: http.StatusOK, @@ -161,29 +144,20 @@ func defineSuccessWhitelistedQuery() struct { return handler }, dbQueries: []testutls.QueryData{}, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } -func defineFailureNoAuthorizationToken() struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineFailureNoAuthorizationToken() testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: false, header: "", wantStatus: http.StatusOK, @@ -195,29 +169,20 @@ func defineFailureNoAuthorizationToken() struct { return nil }, dbQueries: []testutls.QueryData{}, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } -func defineFailureNotAnAdmin() struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineFailureNotAnAdmin() testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: false, header: "bearer 123", wantStatus: http.StatusOK, @@ -229,29 +194,20 @@ func defineFailureNotAnAdmin() struct { return nil }, dbQueries: []testutls.QueryData{}, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } -func defineFailureNoUserWithThatEmail() struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineFailureNoUserWithThatEmail() testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: false, header: "bearer 123", wantStatus: http.StatusOK, @@ -263,29 +219,20 @@ func defineFailureNoUserWithThatEmail() struct { return nil }, dbQueries: []testutls.QueryData{}, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } -func defineFailureInvalidAuthorizationToken() struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -} { - return struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool - }{ +func defineFailureInvalidAuthorizationToken() testGraphQLMiddlewareType { + return testGraphQLMiddlewareType{ whiteListedQuery: false, header: "bearer 123", wantStatus: http.StatusOK, @@ -297,6 +244,15 @@ func defineFailureInvalidAuthorizationToken() struct { return nil }, dbQueries: []testutls.QueryData{}, + init: func(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { + mock, _, _ := testutls.SetupMockDB(t) + for _, dbQuery := range dbQueries { + mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). + WithArgs(*dbQuery.Actions...). + WillReturnRows(dbQuery.DbResponse) + } + return mock + }, } } func defineOperationHandlerSuccessCase(t *testing.T) func(ctx context.Context) graphql2.ResponseHandler { @@ -318,69 +274,7 @@ func defineOperationHandlerSuccessCase(t *testing.T) func(ctx context.Context) g } } -func runTestCases(t *testing.T, cases map[string]struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -}) { - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - runSingleTestCase(t, tt) - }) - } -} - -func runSingleTestCase(t *testing.T, tt struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -}) { - // Set up mock queries - mock := setupMockQueries(t, tt.dbQueries) - - // Determine request query - requestQuery := testutls.MockQuery - if tt.whiteListedQuery { - requestQuery = testutls.MockWhitelistedQuery - } - - // Make request - makeRequest(t, requestQuery, tt) - - // Ensure mock expectations are met - _ = mock.ExpectationsWereMet() -} - -func setupMockQueries(t *testing.T, dbQueries []testutls.QueryData) sqlmock.Sqlmock { - mock, _, _ := testutls.SetupMockDB(t) - for _, dbQuery := range dbQueries { - mock.ExpectQuery(regexp.QuoteMeta(dbQuery.Query)). - WithArgs(*dbQuery.Actions...). - WillReturnRows(dbQuery.DbResponse) - } - return mock -} - -func makeRequest(t *testing.T, requestQuery string, tt struct { - wantStatus int - header string - signMethod string - err string - dbQueries []testutls.QueryData - operationHandler func(ctx context.Context) graphql2.ResponseHandler - tokenParser func(token string) (*jwt.Token, error) - whiteListedQuery bool -}) { +func makeRequest(t *testing.T, requestQuery string, tt testGraphQLMiddlewareType) { // mock token parser to handle the different cases for when the token us valid, invalid, empty parseTokenMock = tt.tokenParser // mock operation handler, and assert different conditions diff --git a/pkg/utl/resultwrapper/error_test.go b/pkg/utl/resultwrapper/error_test.go index ba7e3332..f707b3cd 100644 --- a/pkg/utl/resultwrapper/error_test.go +++ b/pkg/utl/resultwrapper/error_test.go @@ -152,7 +152,6 @@ func TestInternalServerError(t *testing.T) { c echo.Context err error } - errorStr := ErrMsg tests := []struct { name string args args @@ -161,9 +160,9 @@ func TestInternalServerError(t *testing.T) { }{ { name: SuccessCase, - err: errorStr, + err: ErrMsg, args: args{ - err: fmt.Errorf(errorStr), + err: fmt.Errorf(ErrMsg), c: getContext()}, wantErr: true, }, diff --git a/pkg/utl/throttle/throttler_test.go b/pkg/utl/throttle/throttler_test.go index eca31cf2..10267ccf 100644 --- a/pkg/utl/throttle/throttler_test.go +++ b/pkg/utl/throttle/throttler_test.go @@ -129,6 +129,8 @@ func createFailureNotLocalRateLimitExceededTestCase(ctx context.Context) testCas wantErr: true, } } + +// Suggested refactoring to use table-driven tests func TestCheck(t *testing.T) { var ctx context.Context = testutls.MockCtx{} tests := CreateTestCases(ctx) diff --git a/resolver/role_mutations.resolvers_test.go b/resolver/role_mutations.resolvers_test.go index c4b43d57..732fc739 100644 --- a/resolver/role_mutations.resolvers_test.go +++ b/resolver/role_mutations.resolvers_test.go @@ -107,6 +107,8 @@ func successCase() createRoleType { } } +// Suggested refactoring to use table-driven tests and helper functions for common setup. + func errorFromCreateRoleCase() createRoleType { return createRoleType{ name: ErrorFromCreateRole,