diff --git a/io.asgardeo.tomcat.oidc.agent/src/main/java/io/asgardeo/tomcat/oidc/agent/OIDCAgentFilter.java b/io.asgardeo.tomcat.oidc.agent/src/main/java/io/asgardeo/tomcat/oidc/agent/OIDCAgentFilter.java index fbc28c1..d9e0d86 100644 --- a/io.asgardeo.tomcat.oidc.agent/src/main/java/io/asgardeo/tomcat/oidc/agent/OIDCAgentFilter.java +++ b/io.asgardeo.tomcat.oidc.agent/src/main/java/io/asgardeo/tomcat/oidc/agent/OIDCAgentFilter.java @@ -21,10 +21,12 @@ import com.nimbusds.oauth2.sdk.util.StringUtils; import io.asgardeo.java.oidc.sdk.HTTPSessionBasedOIDCProcessor; import io.asgardeo.java.oidc.sdk.SSOAgentConstants; +import io.asgardeo.java.oidc.sdk.bean.RequestContext; import io.asgardeo.java.oidc.sdk.bean.SessionContext; import io.asgardeo.java.oidc.sdk.config.model.OIDCAgentConfig; import io.asgardeo.java.oidc.sdk.exception.SSOAgentClientException; import io.asgardeo.java.oidc.sdk.exception.SSOAgentException; +import io.asgardeo.java.oidc.sdk.exception.SSOAgentServerException; import io.asgardeo.java.oidc.sdk.request.OIDCRequestResolver; import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; @@ -106,6 +108,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo } if (requestResolver.isCallbackResponse()) { + RequestContext requestContext = getRequestContext(request); try { oidcManager.handleOIDCCallback(request, response); } catch (SSOAgentException e) { @@ -117,8 +120,21 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo response.sendRedirect(oidcAgentConfig.getIndexPage()); return; } - response.sendRedirect("home.jsp"); - return; + String homePage = "home.jsp"; + if (StringUtils.isNotBlank(oidcAgentConfig.getHomePage())) { + homePage = oidcAgentConfig.getHomePage(); + response.sendRedirect(homePage); + return; + } + if (requestContext != null) { + if (requestContext.getParameter(SSOAgentConstants.REDIRECT_URI_KEY) != null) { + homePage = requestContext.getParameter(SSOAgentConstants.REDIRECT_URI_KEY).toString(); + response.sendRedirect(homePage); + return; + } + response.sendRedirect(homePage); + return; + } } if (!isActiveSessionPresent(request)) { @@ -132,6 +148,16 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo } } + private RequestContext getRequestContext(HttpServletRequest request) throws SSOAgentServerException { + + HttpSession session = request.getSession(false); + + if (session != null && session.getAttribute(SSOAgentConstants.REQUEST_CONTEXT) != null) { + return (RequestContext) request.getSession(false).getAttribute(SSOAgentConstants.REQUEST_CONTEXT); + } + throw new SSOAgentServerException("Request context null."); + } + @Override public void destroy() {