Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore claims outside request object #2160

Merged
merged 14 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2399,7 +2399,7 @@ private String populateOauthParameters(OAuth2Parameters params, OAuthMessage oAu
}

if (isPkceSupportEnabled()) {
String pkceChallengeCode = getPkceCodeChallenge(oAuthMessage, params);
String pkceChallengeCode = getPkceCodeChallenge(oAuthMessage, params, validationResponse.isPkceMandatory());
String pkceChallengeMethod = getPkceCodeChallengeMethod(oAuthMessage, params);

String redirectURI = validatePKCEParameters(oAuthMessage, validationResponse, pkceChallengeCode,
Expand Down Expand Up @@ -4083,9 +4083,10 @@ private OAuth2Parameters getOAuth2ParamsFromOAuthMessage(OAuthMessage oAuthMessa
*
* @param oAuthMessage oAuthMessage
* @param params OAuth2 Parameters
* @param isPkcemandatory is PKCE mandatory
* @return PKCE code challenge. Priority will be given to the value inside the OAuth2Parameters.
*/
private String getPkceCodeChallenge(OAuthMessage oAuthMessage, OAuth2Parameters params)
private String getPkceCodeChallenge(OAuthMessage oAuthMessage, OAuth2Parameters params, boolean isPkcemandatory)
throws InvalidRequestException {

String pkceChallengeCode = null;
Expand All @@ -4094,8 +4095,8 @@ private String getPkceCodeChallenge(OAuthMessage oAuthMessage, OAuth2Parameters
if (params.getPkceCodeChallenge() != null) {
// If Oauth2 params contains code_challenge get value from Oauth2 params.
pkceChallengeCode = params.getPkceCodeChallenge();
} else if (!isFapiConformantApp) {
// Else retrieve from request query params if application is not FAPI compliant.
} else if (!isFapiConformantApp || isPkcemandatory) {
// Else retrieve from request query params if application is not FAPI compliant or PKCE is mandatory.
pkceChallengeCode = oAuthMessage.getOauthPKCECodeChallenge();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,26 @@ private void validateInputParameters(HttpServletRequest request) throws ParClien
private void validateRequestObject(OAuthAuthzRequest oAuthAuthzRequest) throws ParCoreException {

try {
if (OAuth2Util.isOIDCAuthzRequest(oAuthAuthzRequest.getScopes()) &&
StringUtils.isNotBlank(oAuthAuthzRequest.getParam(REQUEST))) {

OAuth2Parameters parameters = new OAuth2Parameters();
parameters.setClientId(oAuthAuthzRequest.getClientId());
parameters.setRedirectURI(oAuthAuthzRequest.getRedirectURI());
parameters.setResponseType(oAuthAuthzRequest.getResponseType());
parameters.setTenantDomain(getSPTenantDomainFromClientId(oAuthAuthzRequest.getClientId()));

RequestObject requestObject = OIDCRequestObjectUtil.buildRequestObject(oAuthAuthzRequest, parameters);
if (requestObject == null) {
throw new ParClientException(OAuth2ErrorCodes.INVALID_REQUEST, ParConstants.INVALID_REQUEST_OBJECT);
if (OAuth2Util.isOIDCAuthzRequest(oAuthAuthzRequest.getScopes())) {
if (StringUtils.isNotBlank(oAuthAuthzRequest.getParam(REQUEST))) {

OAuth2Parameters parameters = new OAuth2Parameters();
parameters.setClientId(oAuthAuthzRequest.getClientId());
parameters.setRedirectURI(oAuthAuthzRequest.getRedirectURI());
parameters.setResponseType(oAuthAuthzRequest.getResponseType());
parameters.setTenantDomain(getSPTenantDomainFromClientId(oAuthAuthzRequest.getClientId()));

RequestObject requestObject =
OIDCRequestObjectUtil.buildRequestObject(oAuthAuthzRequest, parameters);
if (requestObject == null) {
throw new ParClientException(OAuth2ErrorCodes.INVALID_REQUEST,
ParConstants.INVALID_REQUEST_OBJECT);
}
} else if (isFAPIConformantApp(oAuthAuthzRequest.getClientId())) {
/* Mandate request object for FAPI requests
https://openid.net/specs/openid-financial-api-part-2-1_0.html#authorization-server (5.2.2-1) */
throw new ParClientException(OAuth2ErrorCodes.INVALID_REQUEST, ParConstants.REQUEST_OBJECT_MISSING);
}
} else if (isFAPIConformantApp(oAuthAuthzRequest.getClientId())) {
// Mandate request object for FAPI requests
// https://openid.net/specs/openid-financial-api-part-2-1_0.html#authorization-server (5.2.2-1)
throw new ParClientException(OAuth2ErrorCodes.INVALID_REQUEST, ParConstants.REQUEST_OBJECT_MISSING);
}
} catch (RequestObjectException e) {
if (OAuth2ErrorCodes.SERVER_ERROR.equals(e.getErrorCode())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
import org.wso2.carbon.identity.oauth2.OAuth2Service;
import org.wso2.carbon.identity.oauth2.OAuth2TokenValidationService;
import org.wso2.carbon.identity.oauth2.Oauth2ScopeConstants;
import org.wso2.carbon.identity.oauth2.RequestObjectException;
import org.wso2.carbon.identity.oauth2.bean.OAuthClientAuthnContext;
import org.wso2.carbon.identity.oauth2.bean.Scope;
import org.wso2.carbon.identity.oauth2.dto.OAuth2ClientValidationResponseDTO;
Expand All @@ -100,7 +101,11 @@
import org.wso2.carbon.identity.oauth2.scopeservice.OAuth2Resource;
import org.wso2.carbon.identity.oauth2.scopeservice.ScopeMetadataService;
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;
import org.wso2.carbon.identity.openidconnect.OIDCRequestObjectUtil;
import org.wso2.carbon.identity.openidconnect.RequestObjectBuilder;
import org.wso2.carbon.identity.openidconnect.RequestObjectService;
import org.wso2.carbon.identity.openidconnect.RequestObjectValidator;
import org.wso2.carbon.identity.openidconnect.model.RequestObject;
import org.wso2.carbon.identity.webfinger.DefaultWebFingerProcessor;
import org.wso2.carbon.identity.webfinger.WebFingerProcessor;
import org.wso2.carbon.idp.mgt.IdentityProviderManagementException;
Expand Down Expand Up @@ -1712,7 +1717,18 @@ public static void setParAuthService(ParAuthService parAuthService) {
public static String retrieveStateForErrorURL(HttpServletRequest request, OAuth2Parameters oAuth2Parameters) {

String state = null;
if (oAuth2Parameters != null && oAuth2Parameters.getState() != null) {

if (request.getParameter(OAuthConstants.OAuth20Params.REQUEST) != null) {
String stateInsideRequestObj = getStateFromRequestObject(request, oAuth2Parameters);
if (StringUtils.isNotBlank(stateInsideRequestObj)) {
state = stateInsideRequestObj;
if (log.isDebugEnabled()) {
log.debug("Retrieved state value " + state + " from request object.");
}
}
}

if (StringUtils.isEmpty(state) && oAuth2Parameters != null && oAuth2Parameters.getState() != null) {
RivinduM marked this conversation as resolved.
Show resolved Hide resolved
state = oAuth2Parameters.getState();
if (log.isDebugEnabled()) {
log.debug("Retrieved state value " + state + " from OAuth2Parameters.");
Expand All @@ -1727,6 +1743,32 @@ public static String retrieveStateForErrorURL(HttpServletRequest request, OAuth2
return state;
}

private static String getStateFromRequestObject(HttpServletRequest request, OAuth2Parameters oAuth2Parameters) {

try {
RequestObjectValidator requestObjectValidator = OAuthServerConfiguration.getInstance()
.getRequestObjectValidator();
RequestObjectBuilder requestObjectBuilder = OAuthServerConfiguration.getInstance()
.getRequestObjectBuilders().get("request_param_value_builder");
RivinduM marked this conversation as resolved.
Show resolved Hide resolved
RequestObject requestObject =
requestObjectBuilder.buildRequestObject(request.getParameter(OAuthConstants.OAuth20Params.REQUEST),
oAuth2Parameters);
if (StringUtils.isBlank(oAuth2Parameters.getClientId())) {
// Set client id and tenant domain required for signature validation if not already set.
String clientId = request.getParameter(PROP_CLIENT_ID);
oAuth2Parameters.setClientId(clientId);
oAuth2Parameters.setTenantDomain(getSPTenantDomainFromClientId(clientId));
}
// Validate request object signature to ensure request object is not tampered.
OIDCRequestObjectUtil.validateRequestObjectSignature(oAuth2Parameters, requestObject,
requestObjectValidator);
return requestObject.getClaimValue(OAuthConstants.OAuth20Params.STATE);
} catch (RequestObjectException e) {
log.debug("Error while retrieving state from request object.", e);
RivinduM marked this conversation as resolved.
Show resolved Hide resolved
}
return null;
}

/**
* Return updated redirect URL.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;

import static org.mockito.ArgumentMatchers.contains;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Matchers.any;
Expand Down Expand Up @@ -208,6 +209,7 @@ public class EndpointUtilTest extends PowerMockIdentityBaseTest {
private static final String REQUESTED_OIDC_SCOPES_KEY = "requested_oidc_scopes=";
private static final String REQUESTED_OIDC_SCOPES_VALUES = "openid+profile";
private static final String EXTERNAL_CONSENTED_APP_NAME = "testApp";
private static final String REDIRECT = "redirect";
private static final String EXTERNAL_CONSENT_URL = "https://localhost:9443/consent";
private String username;
private String password;
Expand Down Expand Up @@ -646,7 +648,7 @@ public void testGetErrorPageURL(boolean isImplicitResponse, boolean isHybridResp
when(OAuth2Util.OAuthURL.getOAuth2ErrorPageUrl()).thenReturn(ERROR_PAGE_URL);

when(mockedOAuthResponse.getLocationUri()).thenReturn("http://localhost:8080/location");
when(mockedHttpServletRequest.getParameter(anyString())).thenReturn("http://localhost:8080/location");
when(mockedHttpServletRequest.getParameter(contains(REDIRECT))).thenReturn("http://localhost:8080/location");

String url = EndpointUtil.getErrorPageURL(mockedHttpServletRequest, "invalid request",
"invalid request object", "invalid request", "test", parameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package org.wso2.carbon.identity.oauth2.model;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.oltu.oauth2.as.request.OAuthAuthzRequest;
Expand All @@ -27,17 +26,12 @@
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.apache.oltu.oauth2.common.utils.OAuthUtils;
import org.apache.oltu.oauth2.common.validators.OAuthValidator;
import org.json.JSONObject;
import org.wso2.carbon.identity.central.log.mgt.utils.LogConstants;
import org.wso2.carbon.identity.central.log.mgt.utils.LoggerUtils;
import org.wso2.carbon.identity.oauth.common.OAuthConstants;
import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration;
import org.wso2.carbon.identity.openidconnect.model.Constants;
import org.wso2.carbon.utils.DiagnosticLog;

import java.nio.charset.StandardCharsets;
import java.util.Base64;

import javax.servlet.http.HttpServletRequest;

/**
Expand Down Expand Up @@ -85,25 +79,4 @@ protected OAuthValidator<HttpServletRequest> initValidator() throws OAuthProblem

return OAuthUtils.instantiateClass(clazz);
}

@Override
public String getState() {

/*If request object is present, get the state from the request object.
This state value was required to overridden from the request object in order to make sure the correct state
value(value inside the request object) is sent in error responses prior to building the request object.*/
if (StringUtils.isNotBlank(getParam(Constants.REQUEST))) {
byte[] requestObject;
try {
requestObject = Base64.getDecoder().decode(getParam(Constants.REQUEST).split("\\.")[1]);
} catch (IllegalArgumentException e) {
// Decode if the requestObject is base64-url encoded.
requestObject = Base64.getUrlDecoder().decode(getParam(Constants.REQUEST).split("\\.")[1]);
}
JSONObject requestObjectJson = new JSONObject(new String(requestObject, StandardCharsets.UTF_8));
return requestObjectJson.has(OAuth.OAUTH_STATE) ? requestObjectJson.getString(OAuth.OAUTH_STATE) : null;
} else {
return super.getState();
}
}
}