Skip to content

Commit

Permalink
Address review comments: 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ThaminduDilshan committed Sep 22, 2023
1 parent 8530644 commit 26a76a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
public class OAuthAppTenantResolverValve extends ValveBase {

private static final Log log = LogFactory.getLog(OAuthAppTenantResolverValve.class);
private static final String OAUTH_SERVER_BASE_URL = IdentityUtil.getServerURL("/oauth", true, true) + "/";
private static final String OAUTH2_SERVER_BASE_URL = IdentityUtil.getServerURL("/oauth2", true, true) + "/";

@Override
public void invoke(Request request, Response response) throws IOException, ServletException {
Expand Down Expand Up @@ -112,9 +114,9 @@ public void invoke(Request request, Response response) throws IOException, Servl
*/
private boolean isOAuthRequest(Request request) {

String requestUri = request.getRequestURI();
return StringUtils.isNotEmpty(requestUri) && (requestUri.startsWith("/oauth/") ||
requestUri.startsWith("/oauth2/"));
String requestUrl = request.getRequestURL().toString();
return StringUtils.isNotEmpty(requestUrl) && (requestUrl.startsWith(OAUTH2_SERVER_BASE_URL) ||
requestUrl.startsWith(OAUTH_SERVER_BASE_URL));
}

/**
Expand All @@ -125,8 +127,8 @@ private boolean isOAuthRequest(Request request) {
*/
private boolean isOAuth10ARequest(Request request) {

String requestUri = request.getRequestURI();
return StringUtils.isNotEmpty(requestUri) && requestUri.startsWith("/oauth/");
String requestUrl = request.getRequestURL().toString();
return StringUtils.isNotEmpty(requestUrl) && requestUrl.startsWith(OAUTH_SERVER_BASE_URL);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.testng.PowerMockTestCase;
import org.testng.Assert;
import org.powermock.reflect.Whitebox;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser;
import org.wso2.carbon.identity.central.log.mgt.utils.LoggerUtils;
import org.wso2.carbon.identity.core.util.IdentityTenantUtil;
import org.wso2.carbon.identity.core.util.IdentityUtil;
import org.wso2.carbon.identity.oauth.common.exception.InvalidOAuthClientException;
import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration;
import org.wso2.carbon.identity.oauth.dao.OAuthAppDO;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.client.authentication.OAuthClientAuthnException;
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import javax.servlet.ServletException;

import static org.mockito.Mockito.mock;
Expand All @@ -55,13 +55,13 @@
import static org.wso2.carbon.identity.oauth.common.OAuthConstants.OAuth20Params.CLIENT_ID;
import static org.wso2.carbon.identity.oauth.common.OAuthConstants.TENANT_NAME_FROM_CONTEXT;

@PrepareForTest({OAuthAppTenantResolverValve.class, IdentityTenantUtil.class, OAuth2Util.class,
LoggerUtils.class, OAuthServerConfiguration.class})
@PrepareForTest({OAuthAppTenantResolverValve.class, OAuth2Util.class, LoggerUtils.class,
OAuthServerConfiguration.class, IdentityUtil.class})
public class OAuthAppTenantResolverValveTest extends PowerMockTestCase {

private static final String DUMMY_RESOURCE_OAUTH_2 = "/oauth2/test/resource";
private static final String DUMMY_RESOURCE_OAUTH_10A = "/oauth/test/resource";
private static final String DUMMY_RESOURCE_NON_OAUTH = "/test/resource";
private static final String DUMMY_RESOURCE_OAUTH_2 = "https://localhost:9443/oauth2/test/resource";
private static final String DUMMY_RESOURCE_OAUTH_10A = "https://localhost:9443/oauth/test/resource";
private static final String DUMMY_RESOURCE_NON_OAUTH = "https://localhost:9443/test/resource";
private static final String DUMMY_CLIENT_ID = "client_id";
private static final String DUMMY_CLIENT_SECRET = "client_id";
private static final String TENANT_DOMAIN = "test.tenant";
Expand All @@ -76,11 +76,14 @@ public class OAuthAppTenantResolverValveTest extends PowerMockTestCase {
private OAuthAppTenantResolverValve oAuthAppTenantResolverValve;
private OAuthAppDO oAuthAppDO;

@BeforeMethod
public void setUp() throws Exception {
private ThreadLocal<Map<String, Object>> threadLocalProperties = new ThreadLocal<Map<String, Object>>() {
protected Map<String, Object> initialValue() {
return new HashMap();
}
};

mockStatic(IdentityTenantUtil.class);
when(IdentityTenantUtil.isTenantQualifiedUrlsEnabled()).thenReturn(false);
@BeforeMethod
public void setUp() {

AuthenticatedUser user = new AuthenticatedUser();
user.setTenantDomain(TENANT_DOMAIN);
Expand All @@ -97,8 +100,13 @@ public void setUp() throws Exception {
when(LoggerUtils.isDiagnosticLogsEnabled()).thenReturn(false);
mockStatic(OAuth2Util.class);

mockStatic(IdentityUtil.class);
threadLocalProperties.get().remove(TENANT_NAME_FROM_CONTEXT);
Whitebox.setInternalState(IdentityUtil.class, "threadLocalProperties", threadLocalProperties);
when(IdentityUtil.getServerURL("/oauth", true, true)).thenReturn("https://localhost:9443/oauth");
when(IdentityUtil.getServerURL("/oauth2", true, true)).thenReturn("https://localhost:9443/oauth2");

oAuthAppTenantResolverValve = spy(new OAuthAppTenantResolverValve());
IdentityUtil.threadLocalProperties.get().remove(TENANT_NAME_FROM_CONTEXT);
}

private void invokeAppTenantResolverValve() throws IOException, ServletException {
Expand All @@ -108,18 +116,6 @@ private void invokeAppTenantResolverValve() throws IOException, ServletException
oAuthAppTenantResolverValve.invoke(request, response);
}

@Test
public void testInvokeWhenTenantQualifiedUrlsEnabled() throws Exception {

// Suppress the execution of cleaning methods inorder to assert the correct behaviour.
PowerMockito.suppress(PowerMockito.method(
OAuthAppTenantResolverValve.class, "unsetThreadLocalContextTenantName"));

when(IdentityTenantUtil.isTenantQualifiedUrlsEnabled()).thenReturn(true);
invokeAppTenantResolverValve();
Assert.assertNull(IdentityUtil.threadLocalProperties.get().get(TENANT_NAME_FROM_CONTEXT));
}

@DataProvider
public Object[][] invokeDataProvider() {

Expand All @@ -142,7 +138,7 @@ public void testInvoke(String requestPath, String clientIdParam, String[] header
PowerMockito.suppress(PowerMockito.method(
OAuthAppTenantResolverValve.class, "unsetThreadLocalContextTenantName"));

when(request.getRequestURI()).thenReturn(requestPath);
when(request.getRequestURL()).thenReturn(new StringBuffer(requestPath));
if (requestPath.startsWith("/oauth/")) {
when(request.getParameter(OAUTH_CONSUMER_KEY)).thenReturn(clientIdParam);
} else {
Expand All @@ -167,12 +163,13 @@ public void testInvoke(String requestPath, String clientIdParam, String[] header
}

@Test
public void testInvokeWithException(Exception expectedException) throws Exception {
public void testInvokeWithException() throws Exception {

// Suppress the execution of cleaning methods inorder to assert the correct behaviour.
PowerMockito.suppress(PowerMockito.method(
OAuthAppTenantResolverValve.class, "unsetThreadLocalContextTenantName"));

when(request.getRequestURL()).thenReturn(new StringBuffer(DUMMY_RESOURCE_OAUTH_2));
when(OAuth2Util.isBasicAuthorizationHeaderExists(request)).thenReturn(true);
when(OAuth2Util.extractCredentialsFromAuthzHeader(request)).thenThrow(
new OAuthClientAuthnException("error.message", "error.code"));
Expand All @@ -183,7 +180,7 @@ public void testInvokeWithException(Exception expectedException) throws Exceptio
@Test
public void testInvokeWithUnsetThreadLocal() throws Exception {

when(request.getRequestURI()).thenReturn(DUMMY_RESOURCE_OAUTH_2);
when(request.getRequestURL()).thenReturn(new StringBuffer(DUMMY_RESOURCE_OAUTH_2));
when(request.getParameter(CLIENT_ID)).thenReturn(CLIENT_ID);

when(OAuth2Util.getAppInformationByClientIdOnly(DUMMY_CLIENT_ID)).thenReturn(oAuthAppDO);
Expand Down

0 comments on commit 26a76a1

Please sign in to comment.