Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve already existing query params in callback url when fragment response mode is used. #2158

Original file line number Diff line number Diff line change
Expand Up @@ -551,31 +551,7 @@ public void testAuthorize(Object flowStatusObject, String[] clientId, String ses
when(IdentityDatabaseUtil.getDBConnection()).thenReturn(connection);
mockServiceURLBuilder();
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -760,32 +736,7 @@ public void testAuthorizeForAuthenticationResponse(boolean isResultInRequest, bo
anyString(), isNull(), anyInt(), anyList())).thenReturn(true);

mockServiceURLBuilder();

Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
Response response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
assertEquals(response.getStatus(), expected, "Unexpected HTTP response status");
if (!isAuthenticated) {
Expand Down Expand Up @@ -956,31 +907,7 @@ public void testUserConsentResponse(String consent, String redirectUrl, Set<Stri

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1366,31 +1293,7 @@ public void testHandleUserConsent(boolean isRespDTONull, String consent, boolean

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1533,31 +1436,7 @@ public void testDoUserAuthz(String prompt, String idTokenHint, boolean hasUserAp

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -1710,31 +1589,7 @@ public void testManageOIDCSessionState(Object cookieObject, Object sessionStateO

Response response;
try {
Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);

setSupportedResponseModes();
response = oAuth2AuthzEndpoint.authorize(httpServletRequest, httpServletResponse);
} catch (InvalidRequestParentException ire) {
InvalidRequestExceptionMapper invalidRequestExceptionMapper = new InvalidRequestExceptionMapper();
Expand Down Expand Up @@ -2556,7 +2411,6 @@ private static Object[] appendValue(Object[] originalArray, Object value) {
return newArray;
}

@Test
public void testDeviceCodeGrantCachedClaims () throws Exception {
String userCode = "dummyUserCode";
String deviceCode = "dummyDeviceCode";
Expand Down Expand Up @@ -2591,4 +2445,33 @@ public void testDeviceCodeGrantCachedClaims () throws Exception {
method2.invoke(defaultOIDCClaimsCallbackHandler, deviceCode);
assertEquals(attributeFromCache.get(claimMapping), userAttributes.get(claimMapping));
}

private void setSupportedResponseModes() throws ClassNotFoundException, InstantiationException,
IllegalAccessException {

Map<String, ResponseModeProvider> supportedResponseModeProviders = new HashMap<>();
ResponseModeProvider defaultResponseModeProvider;
Map<String, String> supportedResponseModeClassNames = new HashMap<>();
String defaultResponseModeProviderClassName;
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.QUERY,
QueryResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FRAGMENT,
FragmentResponseModeProvider.class.getCanonicalName());
supportedResponseModeClassNames.put(OAuthConstants.ResponseModes.FORM_POST,
FormPostResponseModeProvider.class.getCanonicalName());
defaultResponseModeProviderClassName = DefaultResponseModeProvider.class.getCanonicalName();

for (Map.Entry<String, String> entry : supportedResponseModeClassNames.entrySet()) {
ResponseModeProvider responseModeProvider = (ResponseModeProvider)
Class.forName(entry.getValue()).newInstance();

supportedResponseModeProviders.put(entry.getKey(), responseModeProvider);
}

defaultResponseModeProvider = (ResponseModeProvider)
Class.forName(defaultResponseModeProviderClassName).newInstance();

OAuth2ServiceComponentHolder.setResponseModeProviders(supportedResponseModeProviders);
OAuth2ServiceComponentHolder.setDefaultResponseModeProvider(defaultResponseModeProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.wso2.carbon.identity.oauth2.responsemode.provider.impl;

import org.apache.commons.lang.StringUtils;
import org.wso2.carbon.identity.application.authentication.framework.util.FrameworkUtils;
import org.wso2.carbon.identity.oauth.common.OAuthConstants;
import org.wso2.carbon.identity.oauth2.responsemode.provider.AbstractResponseModeProvider;
import org.wso2.carbon.identity.oauth2.responsemode.provider.AuthorizationResponseDTO;
Expand Down Expand Up @@ -55,44 +54,41 @@ public String getAuthResponseRedirectUrl(AuthorizationResponseDTO authorizationR
long validityPeriod = authorizationResponseDTO.getSuccessResponseDTO().getValidityPeriod();
String scope = authorizationResponseDTO.getSuccessResponseDTO().getScope();
String authenticatedIdPs = authorizationResponseDTO.getAuthenticatedIDPs();
List<String> queryParams = new ArrayList<>();
List<String> params = new ArrayList<>();
if (accessToken != null) {
queryParams.add(OAuthConstants.ACCESS_TOKEN_RESPONSE_PARAM + "=" + accessToken);
queryParams.add(OAuthConstants.EXPIRES_IN + "=" + validityPeriod);
params.add(OAuthConstants.ACCESS_TOKEN_RESPONSE_PARAM + "=" + accessToken);
params.add(OAuthConstants.EXPIRES_IN + "=" + validityPeriod);
}

if (tokenType != null) {
queryParams.add(OAuthConstants.TOKEN_TYPE + "=" + tokenType);
params.add(OAuthConstants.TOKEN_TYPE + "=" + tokenType);
}

if (idToken != null) {
queryParams.add(OAuthConstants.ID_TOKEN + "=" + idToken);
params.add(OAuthConstants.ID_TOKEN + "=" + idToken);
}

if (code != null) {
queryParams.add(OAuthConstants.CODE + "=" + code);
params.add(OAuthConstants.CODE + "=" + code);
}

if (authenticatedIdPs != null && !authenticatedIdPs.isEmpty()) {
queryParams.add(OAuthConstants.AUTHENTICATED_IDPS + "=" + authenticatedIdPs);
params.add(OAuthConstants.AUTHENTICATED_IDPS + "=" + authenticatedIdPs);
}

if (sessionState != null) {
queryParams.add(OAuthConstants.SESSION_STATE + "=" + sessionState);
params.add(OAuthConstants.SESSION_STATE + "=" + sessionState);
}

if (state != null) {
queryParams.add(OAuthConstants.STATE + "=" + state);
params.add(OAuthConstants.STATE + "=" + state);
}

if (scope != null) {
queryParams.add(OAuthConstants.SCOPE + "=" + scope);
params.add(OAuthConstants.SCOPE + "=" + scope);
}

redirectUrl = FrameworkUtils.appendQueryParamsStringToUrl(redirectUrl,
String.join("&", queryParams));

redirectUrl = redirectUrl.replace("?", "#");
redirectUrl += "#" + String.join("&", params);

} else {
redirectUrl += "#" +
Expand Down
Loading