diff --git a/src/main/java/com/pcb/audy/global/jwt/AuthorizationFilter.java b/src/main/java/com/pcb/audy/global/jwt/AuthorizationFilter.java index 615d995..9e9ab34 100644 --- a/src/main/java/com/pcb/audy/global/jwt/AuthorizationFilter.java +++ b/src/main/java/com/pcb/audy/global/jwt/AuthorizationFilter.java @@ -41,7 +41,7 @@ public class AuthorizationFilter extends OncePerRequestFilter { private static final List whiteList = List.of( new AntPathRequestMatcher("/oauth2/**", HttpMethod.POST.name()), - new AntPathRequestMatcher("/course", HttpMethod.GET.name())); + new AntPathRequestMatcher("/course/**", HttpMethod.GET.name())); @Override protected void doFilterInternal( diff --git a/src/main/java/com/pcb/audy/global/redis/RedisProvider.java b/src/main/java/com/pcb/audy/global/redis/RedisProvider.java index 2fa2f86..811effc 100644 --- a/src/main/java/com/pcb/audy/global/redis/RedisProvider.java +++ b/src/main/java/com/pcb/audy/global/redis/RedisProvider.java @@ -4,6 +4,7 @@ import java.time.Duration; import java.util.List; import java.util.Set; +import java.util.concurrent.TimeUnit; import lombok.RequiredArgsConstructor; import org.springframework.data.redis.connection.RedisStringCommands; import org.springframework.data.redis.core.RedisCallback; @@ -67,6 +68,17 @@ public void multiSet(List pinSaveResList) { }); } + public void setValues(String key, Object o, long expireTime) { + Long len = redisTemplate.opsForList().size(key); + redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(o.getClass())); + redisTemplate.opsForList().remove(key, 1L, o); + redisTemplate.opsForList().rightPush(key, o); + + if (len == 0) { + redisTemplate.expire(key, expireTime, TimeUnit.MILLISECONDS); + } + } + public void delete(String key) { redisTemplate.delete(key); } diff --git a/src/main/java/com/pcb/audy/global/security/SecurityConfig.java b/src/main/java/com/pcb/audy/global/security/SecurityConfig.java index 95861b1..00e613b 100644 --- a/src/main/java/com/pcb/audy/global/security/SecurityConfig.java +++ b/src/main/java/com/pcb/audy/global/security/SecurityConfig.java @@ -48,7 +48,7 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti authorizationHttpRequests .requestMatchers("/oauth2/**") .permitAll() - .requestMatchers("/course") + .requestMatchers("/course/**") .permitAll() .anyRequest() .authenticated()) diff --git a/src/main/java/com/pcb/audy/global/socket/config/WebSocketConfig.java b/src/main/java/com/pcb/audy/global/socket/config/WebSocketConfig.java index eb9a56b..71c325f 100644 --- a/src/main/java/com/pcb/audy/global/socket/config/WebSocketConfig.java +++ b/src/main/java/com/pcb/audy/global/socket/config/WebSocketConfig.java @@ -1,12 +1,12 @@ package com.pcb.audy.global.socket.config; +import com.pcb.audy.global.jwt.JwtUtils; +import com.pcb.audy.global.redis.RedisProvider; import com.pcb.audy.global.socket.handler.CustomHandshakeInterceptor; import com.pcb.audy.global.socket.handler.SocketErrorHandler; -import com.pcb.audy.global.socket.handler.SocketHandler; import lombok.RequiredArgsConstructor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.web.socket.config.annotation.*; import org.springframework.web.socket.server.HandshakeInterceptor; @@ -16,13 +16,14 @@ @RequiredArgsConstructor public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { - private final SocketHandler socketHandler; private final SocketErrorHandler socketErrorHandler; + private final RedisProvider redisProvider; + private final JwtUtils jwtUtils; @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry - .addEndpoint("/course") + .addEndpoint("/course/{courseId}") .setAllowedOriginPatterns("*") .addInterceptors(customHandshakeInterceptor()) .withSockJS(); @@ -35,13 +36,8 @@ public void configureMessageBroker(MessageBrokerRegistry registry) { registry.setApplicationDestinationPrefixes("/pub"); // 클라이언트에서 보낸 메시지 받기 } - @Override - public void configureClientInboundChannel(ChannelRegistration registration) { - registration.interceptors(socketHandler); - } - @Bean public HandshakeInterceptor customHandshakeInterceptor() { - return new CustomHandshakeInterceptor(); + return new CustomHandshakeInterceptor(redisProvider, jwtUtils); } } diff --git a/src/main/java/com/pcb/audy/global/socket/handler/CustomHandshakeInterceptor.java b/src/main/java/com/pcb/audy/global/socket/handler/CustomHandshakeInterceptor.java index fed536b..7c74872 100644 --- a/src/main/java/com/pcb/audy/global/socket/handler/CustomHandshakeInterceptor.java +++ b/src/main/java/com/pcb/audy/global/socket/handler/CustomHandshakeInterceptor.java @@ -2,10 +2,15 @@ import static com.pcb.audy.global.jwt.JwtUtils.ACCESS_TOKEN_NAME; import static com.pcb.audy.global.jwt.JwtUtils.REFRESH_TOKEN_NAME; +import static com.pcb.audy.global.jwt.JwtUtils.TOKEN_TYPE; +import com.pcb.audy.global.jwt.JwtUtils; +import com.pcb.audy.global.redis.RedisProvider; +import com.pcb.audy.global.validator.TokenValidator; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import java.util.Map; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -15,7 +20,11 @@ import org.springframework.web.util.WebUtils; @Slf4j +@RequiredArgsConstructor public class CustomHandshakeInterceptor implements HandshakeInterceptor { + private final RedisProvider redisProvider; + private final JwtUtils jwtUtils; + private final String SOCKET_PREFIX = "socket:"; @Override public boolean beforeHandshake( @@ -25,14 +34,10 @@ public boolean beforeHandshake( Map attributes) { if (request instanceof ServletServerHttpRequest servletServerRequest) { HttpServletRequest servletRequest = servletServerRequest.getServletRequest(); - Cookie accessCookie = WebUtils.getCookie(servletRequest, ACCESS_TOKEN_NAME); - if (accessCookie != null) { - attributes.put(ACCESS_TOKEN_NAME, accessCookie.getValue()); - } - - Cookie refreshCookie = WebUtils.getCookie(servletRequest, REFRESH_TOKEN_NAME); - if (refreshCookie != null) { - attributes.put(REFRESH_TOKEN_NAME, refreshCookie.getValue()); + String email = getEmail(servletRequest, ACCESS_TOKEN_NAME); + if (email == null) { + email = getEmail(servletRequest, REFRESH_TOKEN_NAME); + TokenValidator.validateEmail(email); } } return true; @@ -43,5 +48,23 @@ public void afterHandshake( ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, - Exception exception) {} + Exception exception) { + if (request instanceof ServletServerHttpRequest servletServerRequest) { + HttpServletRequest servletRequest = servletServerRequest.getServletRequest(); + if (servletRequest.getRequestURI() != null) { + String email = getEmail(servletRequest, REFRESH_TOKEN_NAME); + String courseId = servletRequest.getRequestURI().replace("/course/", ""); + redisProvider.setValues(SOCKET_PREFIX + courseId, email, Integer.MAX_VALUE); + } + } + } + + private String getEmail(HttpServletRequest request, String tokenName) { + Cookie cookie = WebUtils.getCookie(request, tokenName); + TokenValidator.validate(cookie); + + String token = cookie.getValue().replace(TOKEN_TYPE, ""); + log.info(tokenName + " in socket: " + token); + return jwtUtils.getEmail(token); + } } diff --git a/src/main/java/com/pcb/audy/global/socket/handler/SocketHandler.java b/src/main/java/com/pcb/audy/global/socket/handler/SocketHandler.java deleted file mode 100644 index bd50918..0000000 --- a/src/main/java/com/pcb/audy/global/socket/handler/SocketHandler.java +++ /dev/null @@ -1,62 +0,0 @@ -package com.pcb.audy.global.socket.handler; - -import static com.pcb.audy.global.jwt.JwtUtils.ACCESS_TOKEN_NAME; -import static com.pcb.audy.global.jwt.JwtUtils.REFRESH_TOKEN_NAME; -import static com.pcb.audy.global.jwt.JwtUtils.TOKEN_TYPE; -import static org.springframework.messaging.simp.stomp.StompCommand.CONNECT; -import static org.springframework.messaging.simp.stomp.StompCommand.SEND; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.pcb.audy.global.jwt.JwtUtils; -import com.pcb.audy.global.validator.TokenValidator; -import java.util.Map; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; -import org.springframework.messaging.simp.stomp.StompCommand; -import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.support.ChannelInterceptor; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -@Slf4j -@Component -@RequiredArgsConstructor -public class SocketHandler implements ChannelInterceptor { - - private final JwtUtils jwtUtils; - private final ObjectMapper objectMapper; - - @Override - public Message preSend(Message message, MessageChannel channel) { - StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); - if (isNeedTokenValidate(accessor.getCommand())) { - Map sessionAttributes = accessor.getSessionAttributes(); - if (!CollectionUtils.isEmpty(sessionAttributes)) { - String accessCookie = - objectMapper.convertValue(sessionAttributes.get(ACCESS_TOKEN_NAME), String.class); - log.info("accessCookie in socket: " + accessCookie); - TokenValidator.validate(accessCookie); - - String accessToken = accessCookie.replace(TOKEN_TYPE, ""); - if (jwtUtils.getEmail(accessToken) == null) { - String refreshCookie = - objectMapper.convertValue(sessionAttributes.get(REFRESH_TOKEN_NAME), String.class); - log.info("refreshCookie in socket: " + refreshCookie); - TokenValidator.validate(refreshCookie); - - String refreshToken = refreshCookie.replace(TOKEN_TYPE, ""); - String email = jwtUtils.getEmail(refreshToken); - TokenValidator.validateEmail(email); - } - } - } - - return message; - } - - private boolean isNeedTokenValidate(StompCommand stompCommand) { - return CONNECT.equals(stompCommand) || SEND.equals(stompCommand); - } -} diff --git a/src/main/java/com/pcb/audy/global/validator/TokenValidator.java b/src/main/java/com/pcb/audy/global/validator/TokenValidator.java index 507d4bd..ab43d52 100644 --- a/src/main/java/com/pcb/audy/global/validator/TokenValidator.java +++ b/src/main/java/com/pcb/audy/global/validator/TokenValidator.java @@ -4,10 +4,17 @@ import static com.pcb.audy.global.response.ResultCode.INVALID_TOKEN; import com.pcb.audy.global.exception.GlobalException; +import jakarta.servlet.http.Cookie; public class TokenValidator { public static void validate(String cookie) { - if (!isExistCookie(cookie)) { + if (!isValidCookie(cookie)) { + throw new GlobalException(INVALID_TOKEN); + } + } + + public static void validate(Cookie cookie) { + if (!isValidCookie(cookie)) { throw new GlobalException(INVALID_TOKEN); } } @@ -22,7 +29,11 @@ private static boolean isExistEmail(String email) { return email != null; } - private static boolean isExistCookie(String cookie) { + private static boolean isValidCookie(Cookie cookie) { + return cookie != null && cookie.getValue().startsWith(TOKEN_TYPE); + } + + private static boolean isValidCookie(String cookie) { return cookie != null && cookie.startsWith(TOKEN_TYPE); } } diff --git a/src/test/java/com/pcb/audy/global/redis/RedisProviderTest.java b/src/test/java/com/pcb/audy/global/redis/RedisProviderTest.java index 99c1cf1..dd86cac 100644 --- a/src/test/java/com/pcb/audy/global/redis/RedisProviderTest.java +++ b/src/test/java/com/pcb/audy/global/redis/RedisProviderTest.java @@ -3,7 +3,9 @@ import static java.lang.Boolean.TRUE; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -14,6 +16,7 @@ import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.data.redis.core.ListOperations; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.ValueOperations; @@ -23,6 +26,7 @@ class RedisProviderTest implements RedisTest { @Mock private RedisTemplate redisTemplate; @Mock private ValueOperations valueOperations; + @Mock private ListOperations listOperations; @Test @DisplayName("데이터 저장 테스트") @@ -82,4 +86,24 @@ class RedisProviderTest implements RedisTest { verify(redisTemplate).hasKey(any()); assertThat(result).isEqualTo(TRUE); } + + // TODO add save list + @Test + @DisplayName("데이터 list에 저장 테스트") + void 데이터_list_저장() { + // given + when(redisTemplate.opsForList()).thenReturn(listOperations); + when(listOperations.size(any())).thenReturn(0L); + when(listOperations.remove(any(), anyLong(), any())).thenReturn(1L); + + // when + redisProvider.setValues(TEST_KEY, TEST_VALUE, TEST_EXPIRE_TIME); + + // then + verify(redisTemplate, times(3)).opsForList(); + verify(listOperations).size(any()); + verify(listOperations).remove(any(), anyLong(), any()); + verify(listOperations).rightPush(any(), any()); + verify(redisTemplate).expire(any(), anyLong(), any()); + } }