diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index 28ce149bf3..83c7491aad 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -144,16 +144,23 @@ func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, } // Update existing session data (extra claims). - updateExtraClaims(session.Extra, respBody.Session.AccessToken) + session.Extra = updateExtraClaims(session.Extra, respBody.Session.AccessToken) idTokenClaims := session.IDTokenClaims() - updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) + idTokenClaims.Extra = updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) return nil } -func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) { +func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) map[string]interface{} { + if webhookExtraClaims == nil { + return claimsToUpdate + } + if claimsToUpdate == nil { + claimsToUpdate = make(map[string]interface{}) + } for key, value := range webhookExtraClaims { claimsToUpdate[key] = value } + return claimsToUpdate } // TokenHook is an AccessRequestHook called for all grant types. diff --git a/oauth2/token_hook_test.go b/oauth2/token_hook_test.go index a313badc0f..d6c3dd24d6 100644 --- a/oauth2/token_hook_test.go +++ b/oauth2/token_hook_test.go @@ -74,15 +74,36 @@ func TestUpdateExtraClaims(t *testing.T) { webhookExtraClaims: map[string]interface{}{}, expected: map[string]interface{}{}, }, + { + name: "Nil webhook claims", + priorExtraClaims: map[string]interface{}{"claim1": "value1"}, + webhookExtraClaims: nil, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Nil prior claims", + priorExtraClaims: nil, + webhookExtraClaims: map[string]interface{}{"claim1": "value1"}, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Both maps nil", + priorExtraClaims: nil, + webhookExtraClaims: nil, + expected: nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Act - updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) + if tt.priorExtraClaims == nil { + tt.priorExtraClaims = nil // Explicitly ensure nil for this test case + } + actual := updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) // Assert - if !reflect.DeepEqual(tt.priorExtraClaims, tt.expected) { + if !reflect.DeepEqual(actual, tt.expected) { t.Errorf("claimsToUpdate = %v, want %v", tt.priorExtraClaims, tt.expected) } })