Skip to content

Commit

Permalink
Merge pull request #2158 from anjuchamantha/query-callback-with-fragment
Browse files Browse the repository at this point in the history
Preserve already existing query params in callback url when fragment response mode is used.
  • Loading branch information
janakamarasena authored Oct 9, 2023
2 parents 5250b12 + 89aab77 commit 83f4cc0
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -552,31 +552,7 @@ public void testAuthorize(Object flowStatusObject, String[] clientId, String ses
when(IdentityDatabaseUtil.getDBConnection()).thenReturn(connection);
mockServiceURLBuilder();
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -762,32 +738,7 @@ public void testAuthorizeForAuthenticationResponse(boolean isResultInRequest, bo
anyString(), isNull(), anyInt(), anyList())).thenReturn(true);

mockServiceURLBuilder();

Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
Response response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
assertEquals(response.getStatus(), expected, "Unexpected HTTP response status");
if (!isAuthenticated) {
Expand Down Expand Up @@ -958,31 +909,7 @@ public void testUserConsentResponse(String consent, String redirectUrl, Set<Stri

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1368,31 +1295,7 @@ public void testHandleUserConsent(boolean isRespDTONull, String consent, boolean

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1535,31 +1438,7 @@ public void testDoUserAuthz(String prompt, String idTokenHint, boolean hasUserAp

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1712,31 +1591,7 @@ public void testManageOIDCSessionState(Object cookieObject, Object sessionStateO

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -2559,7 +2414,6 @@ private static Object[] appendValue(Object[] originalArray, Object value) {
return newArray;
}

@Test
public void testDeviceCodeGrantCachedClaims () throws Exception {
String userCode = "dummyUserCode";
String deviceCode = "dummyDeviceCode";
Expand Down Expand Up @@ -2594,4 +2448,33 @@ public void testDeviceCodeGrantCachedClaims () throws Exception {
method2.invoke(defaultOIDCClaimsCallbackHandler, deviceCode);
assertEquals(attributeFromCache.get(claimMapping), userAttributes.get(claimMapping));
}

private void setSupportedResponseModes() throws ClassNotFoundException, InstantiationException,
IllegalAccessException {

Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.wso2.carbon.identity.oauth2.responsemode.provider.impl;

import org.apache.commons.lang.StringUtils;
import org.wso2.carbon.identity.application.authentication.framework.util.FrameworkUtils;
import org.wso2.carbon.identity.oauth.common.OAuthConstants;
import org.wso2.carbon.identity.oauth2.responsemode.provider.AbstractResponseModeProvider;
import org.wso2.carbon.identity.oauth2.responsemode.provider.AuthorizationResponseDTO;
Expand Down Expand Up @@ -55,44 +54,41 @@ public String getAuthResponseRedirectUrl(AuthorizationResponseDTO authorizationR
long validityPeriod = authorizationResponseDTO.getSuccessResponseDTO().getValidityPeriod();
String scope = authorizationResponseDTO.getSuccessResponseDTO().getScope();
String authenticatedIdPs = authorizationResponseDTO.getAuthenticatedIDPs();
List<String> queryParams = new ArrayList<>();
List<String> params = new ArrayList<>();
if (accessToken != null) {
queryParams.add(OAuthConstants.ACCESS_TOKEN_RESPONSE_PARAM + "=" + accessToken);
queryParams.add(OAuthConstants.EXPIRES_IN + "=" + validityPeriod);
params.add(OAuthConstants.ACCESS_TOKEN_RESPONSE_PARAM + "=" + accessToken);
params.add(OAuthConstants.EXPIRES_IN + "=" + validityPeriod);
}

if (tokenType != null) {
queryParams.add(OAuthConstants.TOKEN_TYPE + "=" + tokenType);
params.add(OAuthConstants.TOKEN_TYPE + "=" + tokenType);
}

if (idToken != null) {
queryParams.add(OAuthConstants.ID_TOKEN + "=" + idToken);
params.add(OAuthConstants.ID_TOKEN + "=" + idToken);
}

if (code != null) {
queryParams.add(OAuthConstants.CODE + "=" + code);
params.add(OAuthConstants.CODE + "=" + code);
}

if (authenticatedIdPs != null && !authenticatedIdPs.isEmpty()) {
queryParams.add(OAuthConstants.AUTHENTICATED_IDPS + "=" + authenticatedIdPs);
params.add(OAuthConstants.AUTHENTICATED_IDPS + "=" + authenticatedIdPs);
}

if (sessionState != null) {
queryParams.add(OAuthConstants.SESSION_STATE + "=" + sessionState);
params.add(OAuthConstants.SESSION_STATE + "=" + sessionState);
}

if (state != null) {
queryParams.add(OAuthConstants.STATE + "=" + state);
params.add(OAuthConstants.STATE + "=" + state);
}

if (scope != null) {
queryParams.add(OAuthConstants.SCOPE + "=" + scope);
params.add(OAuthConstants.SCOPE + "=" + scope);
}

redirectUrl = FrameworkUtils.appendQueryParamsStringToUrl(redirectUrl,
String.join("&", queryParams));

redirectUrl = redirectUrl.replace("?", "#");
redirectUrl += "#" + String.join("&", params);

} else {
redirectUrl += "#" +
Expand Down
Loading

0 comments on commit 83f4cc0

Please sign in to comment.