diff --git a/doc/setup.md b/doc/setup.md index f5555e8fe..81dad0f81 100644 --- a/doc/setup.md +++ b/doc/setup.md @@ -172,10 +172,35 @@ termit: TermIt can operate in two authentication modes: -1. Internal authentication means -2. [Keycloak](https://www.keycloak.org/) -based +1. Internal authentication +2. OAuth2 based (e.g. [Keycloak](https://www.keycloak.org/)) + +By default, OAuth2 is disabled and internal authentication is used +To enable it, set termit security provider to `oidc` +and provide issuer-uri and jwk-set-uri. + +**`application.yml` example:** +```yml +spring: + security: + oauth2: + resourceserver: + jwt: + issuer-uri: http://keycloak.lan/realms/termit + jwk-set-uri: http://keycloak.lan/realms/termit/protocol/openid-connect/certs +termit: + security: + provider: "oidc" +``` + +**Environmental variables example:** +``` +SPRING_SECURITY_OAUTH2_RESOURCESERVER_JWT_ISSUERURI=http://keycloak.lan/realms/termit +SPRING_SECURITY_OAUTH2_RESOURCESERVER_JWT_JWKSETURI=http://keycloak.lan/realms/termit/protocol/openid-connect/certs +TERMIT_SECURITY_PROVIDER=oidc +``` + +TermIt will automatically configure its security accordingly +(it is using Spring's [`ConditionalOnProperty`](https://www.baeldung.com/spring-conditionalonproperty)). -By default, Keycloak is disabled (see `keycloak.enabled` in `application.yml`). To enable it, set `keycloak.enabled` to `true` and -provide additional required Keycloak parameters - see the [Keycloak Spring Boot integration docs](https://www.keycloak.org/docs/latest/securing_apps/#_spring_boot_adapter). -TermIt will automatically configure its security (it is using Spring's [`ConditionalOnProperty`](https://www.baeldung.com/spring-conditionalonproperty)) -accordingly. +**Note that termit-ui needs to be configured for mathcing authentication mode.** diff --git a/src/main/java/cz/cvut/kbss/termit/config/OAuth2SecurityConfig.java b/src/main/java/cz/cvut/kbss/termit/config/OAuth2SecurityConfig.java index ed5d66bde..db6435aa2 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/OAuth2SecurityConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/OAuth2SecurityConfig.java @@ -19,39 +19,29 @@ import cz.cvut.kbss.termit.security.AuthenticationSuccess; import cz.cvut.kbss.termit.security.HierarchicalRoleBasedAuthorityMapper; -import cz.cvut.kbss.termit.security.JwtUtils; import cz.cvut.kbss.termit.security.SecurityConstants; -import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; -import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import cz.cvut.kbss.termit.util.oidc.OidcGrantedAuthoritiesExtractor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Scope; import org.springframework.core.convert.converter.Converter; -import org.springframework.messaging.Message; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.security.authentication.AbstractAuthenticationToken; -import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; -import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.session.SessionRegistryImpl; -import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.session.RegisterSessionAuthenticationStrategy; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; -import org.springframework.util.AntPathMatcher; import org.springframework.web.cors.CorsConfigurationSource; import java.util.Collection; @@ -96,6 +86,17 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { return http.build(); } + /** + * Supplies auth provider which is not exposed by HttpSecurity + * @see cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor + */ + @Bean + public JwtAuthenticationProvider jwtAuthenticationProvider(JwtDecoder jwtDecoder) { + final JwtAuthenticationProvider provider = new JwtAuthenticationProvider(jwtDecoder); + provider.setJwtAuthenticationConverter(grantedAuthoritiesExtractor()); + return provider; + } + private CorsConfigurationSource corsConfigurationSource() { return SecurityConfig.createCorsConfiguration(config.getCors()); } @@ -108,35 +109,4 @@ private Converter grantedAuthoritiesExtractor( new HierarchicalRoleBasedAuthorityMapper().mapAuthorities(authorities)); }; } - - /** - * Part of {@link EnableWebSocketSecurity @EnableWebSocketSecurity} replacement - * - * @see WebSocketConfig - */ - @Bean - @Scope("prototype") - public MessageMatcherDelegatingAuthorizationManager.Builder messageAuthorizationManagerBuilder( - ApplicationContext context) { - return MessageMatcherDelegatingAuthorizationManager.builder().simpDestPathMatcher( - () -> (context.getBeanNamesForType(SimpAnnotationMethodMessageHandler.class).length > 0) - ? context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher() - : new AntPathMatcher()); - } - - /** - * WebSocket endpoint authorization - */ - @Bean - public AuthorizationManager> messageAuthorizationManager( - MessageMatcherDelegatingAuthorizationManager.Builder messages) { - return messages.simpTypeMatchers(SimpMessageType.DISCONNECT).permitAll() - .anyMessage().authenticated().build(); - } - - @Bean - public WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor(JwtUtils jwtUtils, - TermItUserDetailsService userDetailsService) { - return new WebSocketJwtAuthorizationInterceptor(jwtUtils, userDetailsService); - } } diff --git a/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java b/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java index 6e17751a6..aa14405ec 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/SecurityConfig.java @@ -23,36 +23,31 @@ import cz.cvut.kbss.termit.security.JwtAuthorizationFilter; import cz.cvut.kbss.termit.security.JwtUtils; import cz.cvut.kbss.termit.security.SecurityConstants; -import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; +import cz.cvut.kbss.termit.security.TermitJwtDecoder; import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import cz.cvut.kbss.termit.util.Constants; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Scope; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; -import org.springframework.messaging.Message; -import org.springframework.messaging.simp.SimpMessageType; -import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationProvider; -import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder; import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; -import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; -import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; +import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.HttpStatusEntryPoint; -import org.springframework.util.AntPathMatcher; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; @@ -100,7 +95,7 @@ public SecurityConfig(AuthenticationProvider authenticationProvider, } @Bean - public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + public SecurityFilterChain filterChain(HttpSecurity http, TermitJwtDecoder jwtDecoder) throws Exception { LOG.debug("Using internal security mechanisms."); final AuthenticationManager authManager = buildAuthenticationManager(http); http.authorizeHttpRequests((auth) -> auth.requestMatchers(antMatcher("/rest/query")).permitAll() @@ -112,7 +107,7 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception { .logoutSuccessHandler(authenticationSuccessHandler)) .authenticationManager(authManager) .addFilter(authenticationFilter(authManager)) - .addFilter(new JwtAuthorizationFilter(authManager, jwtUtils, userDetailsService, objectMapper)); + .addFilter(new JwtAuthorizationFilter(authManager, jwtUtils, objectMapper, jwtDecoder)); return http.build(); } @@ -131,6 +126,22 @@ private JwtAuthenticationFilter authenticationFilter(AuthenticationManager authe return authenticationFilter; } + /** + * @see cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor + */ + @Bean + public JwtAuthenticationProvider jwtAuthenticationProvider(JwtDecoder jwtDecoder) { + final JwtGrantedAuthoritiesConverter authoritiesConverter = new JwtGrantedAuthoritiesConverter(); + authoritiesConverter.setAuthorityPrefix(""); // this removes default "SCOPE_" prefix + // otherwise, all granted authorities would have this prefix + // (like "SCOPE_ROLE_RESTRICTED_USER", we want just ROLE_...) + final JwtAuthenticationConverter converter = new JwtAuthenticationConverter(); + converter.setJwtGrantedAuthoritiesConverter(authoritiesConverter); + final JwtAuthenticationProvider provider = new JwtAuthenticationProvider(jwtDecoder); + provider.setJwtAuthenticationConverter(converter); + return provider; + } + private CorsConfigurationSource corsConfigurationSource() { return createCorsConfiguration(config.getCors()); } @@ -154,32 +165,8 @@ protected static CorsConfigurationSource createCorsConfiguration( return source; } - /** - * Part of {@link EnableWebSocketSecurity @EnableWebSocketSecurity} replacement - * @see WebSocketConfig - */ - @Bean - @Scope("prototype") - public MessageMatcherDelegatingAuthorizationManager.Builder messageAuthorizationManagerBuilder( - ApplicationContext context) { - return MessageMatcherDelegatingAuthorizationManager.builder().simpDestPathMatcher( - () -> (context.getBeanNamesForType(SimpAnnotationMethodMessageHandler.class).length > 0) - ? context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher() - : new AntPathMatcher()); - } - - /** - * WebSocket endpoint authorization - */ - @Bean - public AuthorizationManager> messageAuthorizationManager( - MessageMatcherDelegatingAuthorizationManager.Builder messages) { - return messages.simpTypeMatchers(SimpMessageType.DISCONNECT).permitAll() - .anyMessage().authenticated().build(); - } - @Bean - public WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor() { - return new WebSocketJwtAuthorizationInterceptor(jwtUtils, userDetailsService); + public TermitJwtDecoder jwtDecoder() { + return new TermitJwtDecoder(jwtUtils, userDetailsService); } } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java index a0ad16d27..36f15d990 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/WebSocketConfig.java @@ -1,41 +1,25 @@ package cz.cvut.kbss.termit.config; import com.fasterxml.jackson.databind.ObjectMapper; -import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; -import cz.cvut.kbss.termit.util.Constants; -import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; -import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; -import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Lazy; +import org.springframework.context.annotation.Scope; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; import org.springframework.messaging.Message; import org.springframework.messaging.converter.MappingJackson2MessageConverter; -import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.StringMessageConverter; -import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; -import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; -import org.springframework.messaging.simp.SimpMessagingTemplate; -import org.springframework.messaging.simp.config.ChannelRegistration; -import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler; import org.springframework.security.authorization.AuthorizationManager; -import org.springframework.security.authorization.SpringAuthorizationEventPublisher; import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; -import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor; -import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; -import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; -import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; -import org.springframework.web.socket.config.annotation.StompEndpointRegistry; -import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; -import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration; +import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager; +import org.springframework.util.AntPathMatcher; import java.nio.charset.StandardCharsets; -import java.util.List; /* We are not using @EnableWebSocketSecurity @@ -43,94 +27,46 @@ it automatically requires CSRF which cannot be configured (disabled) at the mome (will probably change in the future) */ @Configuration -@EnableWebSocketMessageBroker -@Order(Ordered.HIGHEST_PRECEDENCE + 99) // ensures priority above Spring Security -public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { - - private final cz.cvut.kbss.termit.util.Configuration configuration; +@Order(Ordered.HIGHEST_PRECEDENCE + 98) // ensures priority above Spring Security +public class WebSocketConfig { private final ApplicationContext context; - private final AuthorizationManager> messageAuthorizationManager; - - private final WebSocketJwtAuthorizationInterceptor jwtAuthorizationInterceptor; - private final ObjectMapper jsonLdMapper; - private final SimpMessagingTemplate simpMessagingTemplate; - @Autowired - public WebSocketConfig(cz.cvut.kbss.termit.util.Configuration configuration, ApplicationContext context, - AuthorizationManager> messageAuthorizationManager, - WebSocketJwtAuthorizationInterceptor jwtAuthorizationInterceptor, - @Qualifier("jsonLdMapper") ObjectMapper jsonLdMapper, - @Lazy SimpMessagingTemplate simpMessagingTemplate) { - this.configuration = configuration; + public WebSocketConfig(ApplicationContext context, @Qualifier("jsonLdMapper") ObjectMapper jsonLdMapper) { this.context = context; - this.messageAuthorizationManager = messageAuthorizationManager; - this.jwtAuthorizationInterceptor = jwtAuthorizationInterceptor; this.jsonLdMapper = jsonLdMapper; - this.simpMessagingTemplate = simpMessagingTemplate; } - /** - * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) - */ - @Override - public void addArgumentResolvers(List argumentResolvers) { - AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); - argumentResolvers.add(resolver); + @Bean + public StringMessageConverter termitStringMessageConverter() { + return new StringMessageConverter(StandardCharsets.UTF_8); + } + + @Bean + public MappingJackson2MessageConverter termitJsonLdMessageConverter() { + return new MappingJackson2MessageConverter(jsonLdMapper); } /** * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) - * @see Spring security source */ - @Override - public void configureClientInboundChannel(@NotNull ChannelRegistration registration) { - AuthorizationChannelInterceptor interceptor = new AuthorizationChannelInterceptor(this.messageAuthorizationManager); - interceptor.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context)); - registration.interceptors(jwtAuthorizationInterceptor, new SecurityContextChannelInterceptor(), interceptor); - } - - @Override - public void addReturnValueHandlers(List returnValueHandlers) { - returnValueHandlers.add(new WebSocketMessageWithHeadersValueHandler(simpMessagingTemplate)); - } - - @Override - public void registerStompEndpoints(StompEndpointRegistry registry) { - registry.addEndpoint("/ws").setAllowedOrigins(configuration.getCors().getAllowedOrigins().split(",")); - registry.setErrorHandler(new StompExceptionHandler()); - } - - @Override - public void configureMessageBroker(MessageBrokerRegistry registry) { - registry.setApplicationDestinationPrefixes("/") - .setUserDestinationPrefix("/user"); - } - - @Override - public void configureWebSocketTransport(WebSocketTransportRegistration registry) { - registry.setTimeToFirstMessage(Constants.WEBSOCKET_TIME_TO_FIRST_MESSAGE); - registry.setSendBufferSizeLimit(Constants.WEBSOCKET_SEND_BUFFER_SIZE_LIMIT); - } - - @Override - public boolean configureMessageConverters(List messageConverters) { - messageConverters.add(termitJsonLdMessageConverter()); - messageConverters.add(termitStringMessageConverter()); - return false; // do not add default converters - } - @Bean - public MessageConverter termitStringMessageConverter() { - return new StringMessageConverter(StandardCharsets.UTF_8); + @Scope("prototype") + public MessageMatcherDelegatingAuthorizationManager.Builder messageAuthorizationManagerBuilder() { + return MessageMatcherDelegatingAuthorizationManager.builder() + .simpDestPathMatcher(() -> (context.getBeanNamesForType(SimpAnnotationMethodMessageHandler.class).length > 0) ? context.getBean(SimpAnnotationMethodMessageHandler.class) + .getPathMatcher() : new AntPathMatcher()); } + /** + * WebSocket endpoint authorization + */ @Bean - public MessageConverter termitJsonLdMessageConverter() { - return new MappingJackson2MessageConverter(jsonLdMapper); + public AuthorizationManager> messageAuthorizationManager() { + return messageAuthorizationManagerBuilder().simpTypeMatchers(SimpMessageType.DISCONNECT).permitAll() + .anyMessage().authenticated().build(); } - } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java new file mode 100644 index 000000000..ceefa1273 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java @@ -0,0 +1,126 @@ +package cz.cvut.kbss.termit.config; + +import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; +import cz.cvut.kbss.termit.util.Constants; +import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; +import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; +import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; +import org.jetbrains.annotations.NotNull; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.messaging.Message; +import org.springframework.messaging.converter.MappingJackson2MessageConverter; +import org.springframework.messaging.converter.MessageConverter; +import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.config.ChannelRegistration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.security.authorization.AuthorizationManager; +import org.springframework.security.authorization.SpringAuthorizationEventPublisher; +import org.springframework.security.config.annotation.web.socket.EnableWebSocketSecurity; +import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor; +import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver; +import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration; + +import java.util.List; + +@Configuration +@EnableWebSocketMessageBroker +@Order(Ordered.HIGHEST_PRECEDENCE + 99) // ensures priority above Spring Security +public class WebSocketMessageBrokerConfig implements WebSocketMessageBrokerConfigurer { + + private final AuthorizationManager> messageAuthorizationManager; + + private final ApplicationContext context; + + private final WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor; + + private final SimpMessagingTemplate simpMessagingTemplate; + + private final String allowedOrigins; + + private final StringMessageConverter termitStringMessageConverter; + + private final MappingJackson2MessageConverter termitJsonLdMessageConverter; + + private final WebSocketExceptionHandler webSocketExceptionHandler; + + public WebSocketMessageBrokerConfig(AuthorizationManager> messageAuthorizationManager, + ApplicationContext context, + WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor, + @Lazy SimpMessagingTemplate simpMessagingTemplate, + StringMessageConverter termitStringMessageConverter, + MappingJackson2MessageConverter termitJsonLdMessageConverter, + cz.cvut.kbss.termit.util.Configuration configuration, + WebSocketExceptionHandler webSocketExceptionHandler) { + this.messageAuthorizationManager = messageAuthorizationManager; + this.context = context; + this.webSocketJwtAuthorizationInterceptor = webSocketJwtAuthorizationInterceptor; + this.simpMessagingTemplate = simpMessagingTemplate; + this.termitStringMessageConverter = termitStringMessageConverter; + this.termitJsonLdMessageConverter = termitJsonLdMessageConverter; + + this.allowedOrigins = configuration.getCors().getAllowedOrigins(); + this.webSocketExceptionHandler = webSocketExceptionHandler; + } + + /** + * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) + */ + @Override + public void addArgumentResolvers(List argumentResolvers) { + AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver(); + argumentResolvers.add(resolver); + } + + /** + * WebSocket security setup (replaces {@link EnableWebSocketSecurity @EnableWebSocketSecurity}) + * + * @see Spring security source + */ + @Override + public void configureClientInboundChannel(@NotNull ChannelRegistration registration) { + AuthorizationChannelInterceptor interceptor = new AuthorizationChannelInterceptor(messageAuthorizationManager); + interceptor.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(context)); + registration.interceptors(webSocketJwtAuthorizationInterceptor, new SecurityContextChannelInterceptor(), interceptor); + } + + @Override + public void addReturnValueHandlers(List returnValueHandlers) { + returnValueHandlers.add(new WebSocketMessageWithHeadersValueHandler(simpMessagingTemplate)); + } + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.addEndpoint("/ws").setAllowedOrigins(allowedOrigins.split(",")); + registry.setErrorHandler(new StompExceptionHandler(webSocketExceptionHandler)); + } + + @Override + public void configureMessageBroker(MessageBrokerRegistry registry) { + registry.setApplicationDestinationPrefixes("/") + .setUserDestinationPrefix("/user"); + } + + @Override + public void configureWebSocketTransport(WebSocketTransportRegistration registry) { + registry.setTimeToFirstMessage(Constants.WEBSOCKET_TIME_TO_FIRST_MESSAGE); + registry.setSendBufferSizeLimit(Constants.WEBSOCKET_SEND_BUFFER_SIZE_LIMIT); + } + + @Override + public boolean configureMessageConverters(List messageConverters) { + messageConverters.add(termitJsonLdMessageConverter); + messageConverters.add(termitStringMessageConverter); + return false; // do not add default converters + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java index 90eb63cac..03d50a199 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java @@ -39,17 +39,18 @@ import cz.cvut.kbss.termit.exception.WebServiceIntegrationException; import cz.cvut.kbss.termit.exception.importing.UnsupportedImportMediaTypeException; import cz.cvut.kbss.termit.exception.importing.VocabularyImportException; +import jakarta.servlet.http.HttpServletRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; import org.springframework.web.multipart.MaxUploadSizeExceededException; -import jakarta.servlet.http.HttpServletRequest; - /** * Exception handlers for REST controllers. *

@@ -127,6 +128,27 @@ public ResponseEntity authorizationException(HttpServletRequest reque return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN); } + @ExceptionHandler(AuthenticationException.class) + public ResponseEntity authenticationException(HttpServletRequest request, AuthenticationException e) { + LOG.warn("Authentication failure during HTTP request to {}: {}", request.getRequestURI(), e.getMessage()); + LOG.atDebug().setCause(e).log(e.getMessage()); + return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN); + } + + /** + * Fired, for example, on method security violation + */ + @ExceptionHandler(AccessDeniedException.class) + public ResponseEntity accessDeniedException(HttpServletRequest request, AccessDeniedException e) { + LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> { + if (request.getUserPrincipal() != null) { + return request.getUserPrincipal().getName(); + } + return "(unknown user)"; + }).addArgument(e.getMessage()).log(); + return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN); + } + @ExceptionHandler(ValidationException.class) public ResponseEntity validationException(HttpServletRequest request, ValidationException e) { logException(e, request); diff --git a/src/main/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilter.java b/src/main/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilter.java index 06e87770b..12143d6dc 100644 --- a/src/main/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilter.java +++ b/src/main/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilter.java @@ -24,19 +24,20 @@ import cz.cvut.kbss.termit.rest.handler.ErrorInfo; import cz.cvut.kbss.termit.security.model.TermItUserDetails; import cz.cvut.kbss.termit.service.security.SecurityUtils; -import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.DisabledException; import org.springframework.security.authentication.LockedException; import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.security.web.authentication.www.BasicAuthenticationFilter; -import jakarta.servlet.FilterChain; -import jakarta.servlet.ServletException; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.Arrays; import java.util.HashSet; @@ -60,16 +61,16 @@ public class JwtAuthorizationFilter extends BasicAuthenticationFilter { private final JwtUtils jwtUtils; - private final TermItUserDetailsService userDetailsService; - private final ObjectMapper objectMapper; - public JwtAuthorizationFilter(AuthenticationManager authenticationManager, JwtUtils jwtUtils, - TermItUserDetailsService userDetailsService, ObjectMapper objectMapper) { + private final TermitJwtDecoder jwtDecoder; + + public JwtAuthorizationFilter(AuthenticationManager authenticationManager, JwtUtils jwtUtils, ObjectMapper objectMapper, + TermitJwtDecoder jwtDecoder) { super(authenticationManager); this.jwtUtils = jwtUtils; - this.userDetailsService = userDetailsService; this.objectMapper = objectMapper; + this.jwtDecoder = jwtDecoder; } @Override @@ -82,13 +83,16 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } final String authToken = authHeader.substring(SecurityConstants.JWT_TOKEN_PREFIX.length()); try { - final TermItUserDetails userDetails = jwtUtils.extractUserInfo(authToken); - final TermItUserDetails existingDetails = userDetailsService.loadUserByUsername(userDetails.getUsername()); - SecurityUtils.verifyAccountStatus(existingDetails.getUser()); - SecurityUtils.setCurrentUser(existingDetails); - refreshToken(authToken, response); - chain.doFilter(request, response); - } catch (JwtException e) { + Jwt jwt = jwtDecoder.decode(authToken); + final Object principal = jwt.getClaim(JwtClaimNames.SUB); + if (principal instanceof TermItUserDetails existingDetails) { + SecurityUtils.setCurrentUser(existingDetails); + refreshToken(authToken, response); + chain.doFilter(request, response); + } else { + throw new JwtException("Invalid JWT token contents"); + } + } catch (JwtException | org.springframework.security.oauth2.jwt.JwtException e) { if (shouldAllowThroughUnauthenticated(request)) { chain.doFilter(request, response); } else { diff --git a/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java b/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java index 54ca57bdb..3d8b52c4c 100644 --- a/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java +++ b/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java @@ -27,6 +27,8 @@ import cz.cvut.kbss.termit.util.Utils; import io.jsonwebtoken.Claims; import io.jsonwebtoken.ExpiredJwtException; +import io.jsonwebtoken.Jws; +import io.jsonwebtoken.JwtParser; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.MalformedJwtException; import io.jsonwebtoken.SignatureAlgorithm; @@ -35,6 +37,7 @@ import io.jsonwebtoken.jackson.io.JacksonSerializer; import io.jsonwebtoken.security.Keys; import io.jsonwebtoken.security.SecurityException; +import org.jetbrains.annotations.NotNull; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.security.core.GrantedAuthority; @@ -63,11 +66,16 @@ public class JwtUtils { private final Key key; + private final JwtParser jwtParser; + @Autowired public JwtUtils(@Qualifier("objectMapper") ObjectMapper objectMapper, Configuration config) { this.objectMapper = objectMapper; this.key = Utils.isBlank(config.getJwt().getSecretKey()) ? Keys.secretKeyFor(SIGNATURE_ALGORITHM) : Keys.hmacShaKeyFor(config.getJwt().getSecretKey().getBytes(StandardCharsets.UTF_8)); + this.jwtParser = Jwts.parserBuilder().setSigningKey(key) + .deserializeJsonWith(new JacksonDeserializer<>(objectMapper)) + .build(); } /** @@ -109,7 +117,16 @@ private static String mapAuthoritiesToClaim(Collection getClaimsFromToken(String token) { try { return parseClaims(token); } catch (MalformedJwtException | UnsupportedJwtException e) { @@ -133,10 +150,8 @@ private Claims getClaimsFromToken(String token) { } } - private Claims parseClaims(String token) { - return Jwts.parserBuilder().setSigningKey(key) - .deserializeJsonWith(new JacksonDeserializer<>(objectMapper)) - .build().parseClaimsJws(token).getBody(); + private Jws parseClaims(String token) { + return jwtParser.parseClaimsJws(token); } private static void verifyAttributePresence(Claims claims) { @@ -171,7 +186,7 @@ private static List mapClaimToAuthorities(String claim) { */ public String refreshToken(String token) { Objects.requireNonNull(token); - final Claims claims = getClaimsFromToken(token); + final Claims claims = getClaimsFromToken(token).getBody(); final Instant issued = issueTimestamp(); claims.setIssuedAt(Date.from(issued)); claims.setExpiration(Date.from(issued.plusMillis(SecurityConstants.SESSION_TIMEOUT))); @@ -191,7 +206,7 @@ public String refreshToken(String token) { */ public URI getUserUri(String token) { try { - final Claims claims = parseClaims(token); + final Claims claims = parseClaims(token).getBody(); return URI.create(claims.getId()); } catch (ExpiredJwtException e) { return URI.create(e.getClaims().getId()); @@ -206,7 +221,7 @@ public URI getUserUri(String token) { */ public Instant getTokenIssueTimestamp(String token) { try { - final Claims claims = parseClaims(token); + final Claims claims = parseClaims(token).getBody(); return claims.getIssuedAt().toInstant(); } catch (ExpiredJwtException e) { return e.getClaims().getIssuedAt().toInstant(); diff --git a/src/main/java/cz/cvut/kbss/termit/security/TermitJwtDecoder.java b/src/main/java/cz/cvut/kbss/termit/security/TermitJwtDecoder.java new file mode 100644 index 000000000..e46777b14 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/security/TermitJwtDecoder.java @@ -0,0 +1,60 @@ +package cz.cvut.kbss.termit.security; + +import cz.cvut.kbss.termit.security.model.TermItUserDetails; +import cz.cvut.kbss.termit.service.security.SecurityUtils; +import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jws; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; +import org.springframework.security.oauth2.jwt.JwtException; + +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * @see #decode(String) + */ +public class TermitJwtDecoder implements org.springframework.security.oauth2.jwt.JwtDecoder { + + private final JwtUtils jwtUtils; + + private final TermItUserDetailsService userDetailsService; + + public TermitJwtDecoder(JwtUtils jwtUtils, TermItUserDetailsService userDetailsService) { + this.jwtUtils = jwtUtils; + this.userDetailsService = userDetailsService; + } + + /** + * Decodes JWT token (without the {@code Bearer} prefix) + * and ensures its validity. + * @throws JwtException with cause, when token could not be decoded or verified + */ + @Override + public Jwt decode(String token) throws JwtException { + try { + final Jws expanded = jwtUtils.getClaimsFromToken(token); + Objects.requireNonNull(expanded); + Objects.requireNonNull(expanded.getBody()); + Objects.requireNonNull(expanded.getHeader()); + final Claims claims = expanded.getBody(); + Objects.requireNonNull(claims.getIssuedAt()); + Objects.requireNonNull(claims.getExpiration()); + final TermItUserDetails tokenDetails = jwtUtils.extractUserInfo(claims); + final TermItUserDetails existingDetails = userDetailsService.loadUserByUsername(tokenDetails.getUsername()); + + SecurityUtils.verifyAccountStatus(existingDetails.getUser()); + + claims.put("scope", existingDetails.getAuthorities().stream().map(GrantedAuthority::getAuthority) + .collect(Collectors.toSet())); + claims.put(JwtClaimNames.SUB, existingDetails); + + return new Jwt(token, claims.getIssuedAt().toInstant(), claims.getExpiration() + .toInstant(), expanded.getHeader(), claims); + } catch (cz.cvut.kbss.termit.exception.JwtException | NullPointerException e) { + throw new JwtException(e.getMessage(), e); + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java index 71e5627de..eda4786a7 100644 --- a/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java +++ b/src/main/java/cz/cvut/kbss/termit/security/WebSocketJwtAuthorizationInterceptor.java @@ -1,10 +1,5 @@ package cz.cvut.kbss.termit.security; -import cz.cvut.kbss.termit.exception.AuthorizationException; -import cz.cvut.kbss.termit.exception.JwtException; -import cz.cvut.kbss.termit.security.model.TermItUserDetails; -import cz.cvut.kbss.termit.service.security.SecurityUtils; -import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import org.jetbrains.annotations.NotNull; import org.springframework.http.HttpHeaders; import org.springframework.messaging.Message; @@ -13,25 +8,31 @@ import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.MessageHeaderAccessor; -import org.springframework.security.authentication.DisabledException; -import org.springframework.security.authentication.LockedException; +import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.core.Authentication; -import org.springframework.security.core.userdetails.UsernameNotFoundException; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.server.resource.InvalidBearerTokenException; +import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthenticationToken; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; +import org.springframework.stereotype.Component; +import org.springframework.util.StringUtils; /** - * Authorizes STOMP CONNECT messages + * Authenticates STOMP CONNECT messages *

- * Retrieves token from the {@code Authorization} header of STOMP message and validates JWT token. + * Retrieves token from the {@code Authorization} header + * and uses {@link JwtAuthenticationProvider} to authenticate the token. + * @see Consult this Stackoverflow answer */ +@Component public class WebSocketJwtAuthorizationInterceptor implements ChannelInterceptor { - private final JwtUtils jwtUtils; + private final JwtAuthenticationProvider jwtAuthenticationProvider; - private final TermItUserDetailsService userDetailsService; - - public WebSocketJwtAuthorizationInterceptor(JwtUtils jwtUtils, TermItUserDetailsService userDetailsService) { - this.jwtUtils = jwtUtils; - this.userDetailsService = userDetailsService; + public WebSocketJwtAuthorizationInterceptor(JwtAuthenticationProvider jwtAuthenticationProvider) { + this.jwtAuthenticationProvider = jwtAuthenticationProvider; } @Override @@ -39,27 +40,48 @@ public Message preSend(@NotNull Message message, @NotNull MessageChannel c StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); if (headerAccessor != null && StompCommand.CONNECT.equals(headerAccessor.getCommand()) && headerAccessor.isMutable()) { final String authHeader = headerAccessor.getFirstNativeHeader(HttpHeaders.AUTHORIZATION); - if (authHeader != null && authHeader.startsWith(SecurityConstants.JWT_TOKEN_PREFIX)) { + if (authHeader != null) { headerAccessor.removeNativeHeader(HttpHeaders.AUTHORIZATION); - return process(message, authHeader, headerAccessor); + process(headerAccessor, authHeader); + return message; } - throw new AuthorizationException("Authorization header is invalid"); + throw new AuthenticationCredentialsNotFoundException("Invalid authorization header"); } return message; } - private Message process(final @NotNull Message message, final @NotNull String authHeader, - final @NotNull StompHeaderAccessor headerAccessor) { - final String authToken = authHeader.substring(SecurityConstants.JWT_TOKEN_PREFIX.length()); + /** + * Authenticates user using JWT token in authentication header + *

+ * According to Open ID spec, + * the token MUST be {@code Bearer}. + * And for example {@link org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider} + * also supports only {@code Bearer} tokens. + */ + protected void process(StompHeaderAccessor stompHeaderAccessor, final @NotNull String authHeader) { + if (!StringUtils.startsWithIgnoreCase(authHeader, SecurityConstants.JWT_TOKEN_PREFIX)) { + throw new InvalidBearerTokenException("Invalid Bearer token in authorization header"); + } + + final String token = authHeader.substring(SecurityConstants.JWT_TOKEN_PREFIX.length()); + + BearerTokenAuthenticationToken authenticationRequest = new BearerTokenAuthenticationToken(token); + try { - final TermItUserDetails userDetails = jwtUtils.extractUserInfo(authToken); - final TermItUserDetails existingDetails = userDetailsService.loadUserByUsername(userDetails.getUsername()); - SecurityUtils.verifyAccountStatus(existingDetails.getUser()); - Authentication authentication = SecurityUtils.setCurrentUser(existingDetails); - headerAccessor.setUser(authentication); - return message; - } catch (JwtException | DisabledException | LockedException | UsernameNotFoundException e) { - throw new AuthorizationException(e.getMessage()); + Authentication authenticationResult = jwtAuthenticationProvider.authenticate(authenticationRequest); + if (authenticationResult != null && authenticationResult.isAuthenticated()) { + SecurityContext context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(authenticationResult); + SecurityContextHolder.setContext(context); + stompHeaderAccessor.setUser(authenticationResult); + return; // all ok + } + throw new OAuth2AuthenticationException("Authentication failed"); + } catch (Exception e) { + // ensure that context is cleared when any exception happens + stompHeaderAccessor.setUser(null); + SecurityContextHolder.clearContext(); + throw e; } } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/security/SecurityUtils.java b/src/main/java/cz/cvut/kbss/termit/service/security/SecurityUtils.java index 758e82a33..18aa2993e 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/security/SecurityUtils.java +++ b/src/main/java/cz/cvut/kbss/termit/service/security/SecurityUtils.java @@ -35,6 +35,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimNames; import org.springframework.stereotype.Service; import java.util.Objects; @@ -70,12 +71,17 @@ public SecurityUtils(UserDetailsService userDetailsService, PasswordEncoder pass public UserAccount getCurrentUser() { final SecurityContext context = SecurityContextHolder.getContext(); assert context != null && context.getAuthentication().isAuthenticated(); - if (context.getAuthentication().getPrincipal() instanceof Jwt) { + if (context.getAuthentication().getPrincipal() instanceof Jwt jwt) { + Object principal = jwt.getClaim(JwtClaimNames.SUB); + if(principal instanceof TermItUserDetails termItUserDetails) { + return termItUserDetails.getUser(); + } + return resolveAccountFromOAuthPrincipal(context); - } else { - final TermItUserDetails userDetails = (TermItUserDetails) context.getAuthentication().getDetails(); - return userDetails.getUser(); } + + final TermItUserDetails userDetails = (TermItUserDetails) context.getAuthentication().getDetails(); + return userDetails.getUser(); } private UserAccount resolveAccountFromOAuthPrincipal(SecurityContext context) { diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java index 4f78bf920..4981eed7c 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/StompExceptionHandler.java @@ -5,17 +5,75 @@ import org.slf4j.LoggerFactory; import org.springframework.messaging.Message; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * calls {@link WebSocketExceptionHandler} when possible, otherwise logs exception as error + */ public class StompExceptionHandler extends StompSubProtocolErrorHandler { private static final Logger LOG = LoggerFactory.getLogger(StompExceptionHandler.class); + private final WebSocketExceptionHandler webSocketExceptionHandler; + + public StompExceptionHandler(WebSocketExceptionHandler webSocketExceptionHandler) { + this.webSocketExceptionHandler = webSocketExceptionHandler; + } + @Override protected @NotNull Message handleInternal(@NotNull StompHeaderAccessor errorHeaderAccessor, - byte @NotNull [] errorPayload, - Throwable cause, StompHeaderAccessor clientHeaderAccessor) { - LOG.error("STOMP sub-protocol exception", cause); + byte @NotNull [] errorPayload, Throwable cause, + StompHeaderAccessor clientHeaderAccessor) { + final Message message = MessageBuilder.withPayload(errorPayload).setHeaders(errorHeaderAccessor).build(); + boolean handled = false; + try { + handled = delegate(message, cause); + } catch (InvocationTargetException e) { + LOG.error("Exception thrown during exception handler invocation", e); + } catch (IllegalAccessException unexpected) { + // is checked by delegate + } + + if (!handled) { + LOG.error("STOMP sub-protocol exception", cause); + } + return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); } + + /** + * Tries to match method on {@link #webSocketExceptionHandler} + * + * @return true when a method was found and called, false otherwise + * @throws IllegalArgumentException never + */ + private boolean delegate(Message message, Throwable throwable) + throws InvocationTargetException, IllegalAccessException { + if (throwable instanceof Exception exception) { + Method[] methods = webSocketExceptionHandler.getClass().getMethods(); + for (final Method method : methods) { + if (!method.canAccess(webSocketExceptionHandler)) { + continue; + } + Class[] params = method.getParameterTypes(); + if (params.length != 2) { + continue; + } + if (params[0].isAssignableFrom(message.getClass()) && params[1].isAssignableFrom(exception.getClass())) { + // message, exception + method.invoke(webSocketExceptionHandler, message, exception); + return true; + } else if (params[0].isAssignableFrom(exception.getClass()) && params[1].isAssignableFrom(message.getClass())) { + // exception, message + method.invoke(webSocketExceptionHandler, exception, message); + return true; + } + } + } + return false; + } } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java index 50411c650..e94b99450 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -30,6 +30,8 @@ import org.springframework.messaging.handler.annotation.MessageExceptionHandler; import org.springframework.messaging.simp.annotation.SendToUser; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.multipart.MaxUploadSizeExceededException; @@ -44,7 +46,12 @@ public class WebSocketExceptionHandler { private static final Logger LOG = LoggerFactory.getLogger(WebSocketExceptionHandler.class); private static String destination(Message message) { - return message.getHeaders().getOrDefault("destination", "missing destination").toString(); + return message.getHeaders().getOrDefault("destination", "(missing destination)").toString(); + } + + private static boolean hasDestination(Message message) { + final String dst = (String) message.getHeaders().getOrDefault("destination", ""); + return dst != null && !dst.isBlank(); } private static void logException(TermItException ex, Message message) { @@ -72,8 +79,13 @@ private static ErrorInfo errorInfo(Message message, Throwable e) { @MessageExceptionHandler public void messageDeliveryException(Message message, MessageDeliveryException e) { - final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); - LOG.error("Failed to send message with destination {}: {}", headerAccessor.getDestination(), e.getMessage()); + // messages without destination will be logged only on trace + (hasDestination(message) ? LOG.atError() : LOG.atTrace()) + .setMessage("Failed to send message with destination {}: {}") + .addArgument(()-> destination(message)) + .addArgument(e.getMessage()) + .setCause(e.getCause()) + .log(); } @MessageExceptionHandler(PersistenceException.class) @@ -117,6 +129,28 @@ public ErrorInfo authorizationException(Message message, AuthorizationExcepti return errorInfo(message, e); } + @MessageExceptionHandler(AuthenticationException.class) + public ErrorInfo authenticationException(Message message, AuthenticationException e) { + LOG.atDebug().setCause(e).log(e.getMessage()); + LOG.error("Authentication failure during message processing: {}\nMessage: {}", e.getMessage(), message.toString()); + return errorInfo(message, e); + } + + /** + * Fired, for example, on method security violation + */ + @MessageExceptionHandler(AccessDeniedException.class) + public ErrorInfo accessDeniedException(Message message, AccessDeniedException e) { + LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> { + StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); + if (accessor.getUser() != null) { + return accessor.getUser().getName(); + } + return "(unknown user)"; + }).addArgument(e.getMessage()).log(); + return errorInfo(message, e); + } + @MessageExceptionHandler(ValidationException.class) public ErrorInfo validationException(Message message, ValidationException e) { logException(e, message); diff --git a/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java index f3cd62e7d..1ce9b63fd 100644 --- a/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java +++ b/src/test/java/cz/cvut/kbss/termit/environment/config/TestWebSocketConfig.java @@ -2,6 +2,8 @@ import cz.cvut.kbss.termit.config.WebAppConfig; import cz.cvut.kbss.termit.config.WebSocketConfig; +import cz.cvut.kbss.termit.config.WebSocketMessageBrokerConfig; +import cz.cvut.kbss.termit.security.WebSocketJwtAuthorizationInterceptor; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate; import org.jetbrains.annotations.NotNull; @@ -24,6 +26,8 @@ import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.support.AbstractSubscribableChannel; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import java.util.HashMap; @@ -32,9 +36,10 @@ import java.util.UUID; @TestConfiguration +@EnableWebSocketMessageBroker @EnableConfigurationProperties(Configuration.class) -@Import({TestSecurityConfig.class, TestRestSecurityConfig.class, WebAppConfig.class, WebSocketConfig.class}) -@ComponentScan(basePackages = "cz.cvut.kbss.termit.websocket") +@Import({TestSecurityConfig.class, TestRestSecurityConfig.class, WebAppConfig.class, WebSocketConfig.class, WebSocketMessageBrokerConfig.class}) +@ComponentScan(basePackages = {"cz.cvut.kbss.termit.websocket", "cz.cvut.kbss.termit.websocket.handler"}) public class TestWebSocketConfig implements ApplicationListener, WebSocketMessageBrokerConfigurer { @@ -95,4 +100,9 @@ public SimpMessagingTemplate brokerMessagingTemplate( template.setMessageConverter(brokerMessageConverter); return template; } + + @Bean + public WebSocketJwtAuthorizationInterceptor webSocketJwtAuthorizationInterceptor(JwtAuthenticationProvider jwtAuthenticationProvider) { + return new WebSocketJwtAuthorizationInterceptor(jwtAuthenticationProvider); + } } diff --git a/src/test/java/cz/cvut/kbss/termit/security/JwtAuthenticationFilterTest.java b/src/test/java/cz/cvut/kbss/termit/security/JwtAuthenticationFilterTest.java index 1294826fa..61e494d77 100644 --- a/src/test/java/cz/cvut/kbss/termit/security/JwtAuthenticationFilterTest.java +++ b/src/test/java/cz/cvut/kbss/termit/security/JwtAuthenticationFilterTest.java @@ -28,6 +28,7 @@ import io.jsonwebtoken.Jws; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.security.Keys; +import jakarta.servlet.FilterChain; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -43,11 +44,12 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; -import jakarta.servlet.FilterChain; import java.nio.charset.StandardCharsets; import java.util.Collections; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; @Tag("security") diff --git a/src/test/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilterTest.java b/src/test/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilterTest.java index e273fa08a..49813826d 100644 --- a/src/test/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilterTest.java +++ b/src/test/java/cz/cvut/kbss/termit/security/JwtAuthorizationFilterTest.java @@ -25,11 +25,11 @@ import cz.cvut.kbss.termit.rest.ConfigurationController; import cz.cvut.kbss.termit.rest.handler.ErrorInfo; import cz.cvut.kbss.termit.security.model.TermItUserDetails; -import cz.cvut.kbss.termit.service.security.SecurityUtils; import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import cz.cvut.kbss.termit.util.Configuration; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.security.Keys; +import jakarta.servlet.FilterChain; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; @@ -49,7 +49,6 @@ import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; -import jakarta.servlet.FilterChain; import java.nio.charset.StandardCharsets; import java.security.Key; import java.time.Instant; @@ -59,8 +58,17 @@ import static cz.cvut.kbss.termit.util.Constants.REST_MAPPING_PATH; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @Tag("security") @ExtendWith({SpringExtension.class, MockitoExtension.class}) @@ -93,6 +101,8 @@ class JwtAuthorizationFilterTest { private JwtAuthorizationFilter sut; + private TermitJwtDecoder termitJwtDecoder; + private final Instant tokenIssued = JwtUtils.issueTimestamp(); @BeforeEach @@ -101,8 +111,8 @@ void setUp() { this.objectMapper = Environment.getObjectMapper(); this.signingKey = Keys.hmacShaKeyFor(config.getJwt().getSecretKey().getBytes(StandardCharsets.UTF_8)); this.jwtUtilsSpy = spy(new JwtUtils(objectMapper, config)); - this.sut = new JwtAuthorizationFilter(authManagerMock, jwtUtilsSpy, detailsServiceMock, - objectMapper); + this.termitJwtDecoder = new TermitJwtDecoder(jwtUtilsSpy, detailsServiceMock); + this.sut = new JwtAuthorizationFilter(authManagerMock, jwtUtilsSpy, objectMapper, termitJwtDecoder); } @AfterEach diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java index 4c4c944d8..da6684097 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java @@ -1,5 +1,6 @@ package cz.cvut.kbss.termit.websocket; +import cz.cvut.kbss.termit.environment.config.TestRestSecurityConfig; import cz.cvut.kbss.termit.environment.config.TestWebSocketConfig; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.websocket.util.CachingChannelInterceptor; @@ -31,7 +32,7 @@ @ExtendWith(MockitoExtension.class) @EnableConfigurationProperties({Configuration.class}) @DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) -@ContextConfiguration(classes = {TestWebSocketConfig.class}, +@ContextConfiguration(classes = {TestRestSecurityConfig.class, TestWebSocketConfig.class}, initializers = {ConfigDataApplicationContextInitializer.class}) public abstract class BaseWebSocketControllerTestRunner { diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java index 59ba7d502..ce30493d2 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java @@ -4,10 +4,10 @@ import cz.cvut.kbss.termit.config.SecurityConfig; import cz.cvut.kbss.termit.config.WebAppConfig; import cz.cvut.kbss.termit.config.WebSocketConfig; +import cz.cvut.kbss.termit.config.WebSocketMessageBrokerConfig; import cz.cvut.kbss.termit.environment.Generator; import cz.cvut.kbss.termit.environment.config.TestConfig; import cz.cvut.kbss.termit.environment.config.TestPersistenceConfig; -import cz.cvut.kbss.termit.environment.config.TestSecurityConfig; import cz.cvut.kbss.termit.environment.config.TestServiceConfig; import cz.cvut.kbss.termit.security.JwtUtils; import cz.cvut.kbss.termit.security.model.TermItUserDetails; @@ -16,19 +16,14 @@ import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Answers; -import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.autoconfigure.AutoConfigureOrder; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.EnableAspectJAutoProxy; @@ -37,7 +32,6 @@ import org.springframework.messaging.simp.stomp.StompHeaders; import org.springframework.messaging.simp.stomp.StompSession; import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; -import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.ContextConfiguration; @@ -53,7 +47,6 @@ import java.util.concurrent.atomic.AtomicReference; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.when; @ActiveProfiles("test") @EnableSpringConfigured @@ -63,10 +56,11 @@ @EnableAspectJAutoProxy(proxyTargetClass = true) @EnableConfigurationProperties({Configuration.class}) @ContextConfiguration( - classes = {TestConfig.class, TestPersistenceConfig.class, TestConfig.class, - TestServiceConfig.class, AppConfig.class, SecurityConfig.class, WebAppConfig.class, WebSocketConfig.class}, + classes = {TestConfig.class, TestPersistenceConfig.class, TestServiceConfig.class, AppConfig.class, + SecurityConfig.class, WebAppConfig.class, WebSocketConfig.class, WebSocketMessageBrokerConfig.class}, initializers = {ConfigDataApplicationContextInitializer.class}) -@ComponentScan("cz.cvut.kbss.termit.security") +@ComponentScan( + {"cz.cvut.kbss.termit.security", "cz.cvut.kbss.termit.websocket", "cz.cvut.kbss.termit.websocket.handler"}) @DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) public abstract class BaseWebSocketIntegrationTestRunner {