diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json index 3748c3744f..e1381a9044 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -14,7 +14,8 @@ "amr": null, "c_hash": "", "ext": { - "hooked": "legacy" + "hooked": "legacy", + "sid": "" } }, "headers": { diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index d32cadd7e4..83c7491aad 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -143,13 +143,26 @@ func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, ) } - // Overwrite existing session data (extra claims). - session.Extra = respBody.Session.AccessToken + // Update existing session data (extra claims). + session.Extra = updateExtraClaims(session.Extra, respBody.Session.AccessToken) idTokenClaims := session.IDTokenClaims() - idTokenClaims.Extra = respBody.Session.IDToken + idTokenClaims.Extra = updateExtraClaims(idTokenClaims.Extra, respBody.Session.IDToken) return nil } +func updateExtraClaims(claimsToUpdate, webhookExtraClaims map[string]interface{}) map[string]interface{} { + if webhookExtraClaims == nil { + return claimsToUpdate + } + if claimsToUpdate == nil { + claimsToUpdate = make(map[string]interface{}) + } + for key, value := range webhookExtraClaims { + claimsToUpdate[key] = value + } + return claimsToUpdate +} + // TokenHook is an AccessRequestHook called for all grant types. func TokenHook(reg interface { config.Provider diff --git a/oauth2/token_hook_test.go b/oauth2/token_hook_test.go new file mode 100644 index 0000000000..d6c3dd24d6 --- /dev/null +++ b/oauth2/token_hook_test.go @@ -0,0 +1,111 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "reflect" + "testing" +) + +func TestUpdateExtraClaims(t *testing.T) { + tests := []struct { + name string + priorExtraClaims map[string]interface{} + webhookExtraClaims map[string]interface{} + expected map[string]interface{} + }{ + { + name: "Merge with no updates", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + }, + webhookExtraClaims: map[string]interface{}{ + "claim3": "value3", + "claim4": "value4", + }, + expected: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + "claim3": "value3", + "claim4": "value4", + }, + }, + { + name: "Merge with updates", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + "claim2": "value2", + }, + webhookExtraClaims: map[string]interface{}{ + "claim2": "newValue2", // Overwrites prior claim2 + "claim3": "value3", + }, + expected: map[string]interface{}{ + "claim1": "value1", + "claim2": "newValue2", + "claim3": "value3", + }, + }, + { + name: "Empty webhook claims", + priorExtraClaims: map[string]interface{}{ + "claim1": "value1", + }, + webhookExtraClaims: map[string]interface{}{}, + expected: map[string]interface{}{ + "claim1": "value1", + }, + }, + { + name: "Empty prior claims", + priorExtraClaims: map[string]interface{}{}, + webhookExtraClaims: map[string]interface{}{ + "claim1": "value1", + }, + expected: map[string]interface{}{ + "claim1": "value1", + }, + }, + { + name: "Both maps empty", + priorExtraClaims: map[string]interface{}{}, + webhookExtraClaims: map[string]interface{}{}, + expected: map[string]interface{}{}, + }, + { + name: "Nil webhook claims", + priorExtraClaims: map[string]interface{}{"claim1": "value1"}, + webhookExtraClaims: nil, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Nil prior claims", + priorExtraClaims: nil, + webhookExtraClaims: map[string]interface{}{"claim1": "value1"}, + expected: map[string]interface{}{"claim1": "value1"}, + }, + { + name: "Both maps nil", + priorExtraClaims: nil, + webhookExtraClaims: nil, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Act + if tt.priorExtraClaims == nil { + tt.priorExtraClaims = nil // Explicitly ensure nil for this test case + } + actual := updateExtraClaims(tt.priorExtraClaims, tt.webhookExtraClaims) + + // Assert + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("claimsToUpdate = %v, want %v", tt.priorExtraClaims, tt.expected) + } + }) + } +}