Skip to content

Commit

Permalink
Merge pull request #113 from Nexters/feature/112-socket-settings
Browse files Browse the repository at this point in the history
소켓 설정 수정
  • Loading branch information
emost22 authored Feb 24, 2024
2 parents 5aa53aa + 4b7f6d4 commit 8266e78
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class AuthorizationFilter extends OncePerRequestFilter {
private static final List<RequestMatcher> whiteList =
List.of(
new AntPathRequestMatcher("/oauth2/**", HttpMethod.POST.name()),
new AntPathRequestMatcher("/course", HttpMethod.GET.name()));
new AntPathRequestMatcher("/course/**", HttpMethod.GET.name()));

@Override
protected void doFilterInternal(
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/pcb/audy/global/redis/RedisProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.time.Duration;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.connection.RedisStringCommands;
import org.springframework.data.redis.core.RedisCallback;
Expand Down Expand Up @@ -67,6 +68,17 @@ public void multiSet(List<PinRedisRes> pinSaveResList) {
});
}

public void setValues(String key, Object o, long expireTime) {
Long len = redisTemplate.opsForList().size(key);
redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(o.getClass()));
redisTemplate.opsForList().remove(key, 1L, o);
redisTemplate.opsForList().rightPush(key, o);

if (len == 0) {
redisTemplate.expire(key, expireTime, TimeUnit.MILLISECONDS);
}
}

public void delete(String key) {
redisTemplate.delete(key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti
authorizationHttpRequests
.requestMatchers("/oauth2/**")
.permitAll()
.requestMatchers("/course")
.requestMatchers("/course/**")
.permitAll()
.anyRequest()
.authenticated())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package com.pcb.audy.global.socket.config;

import com.pcb.audy.global.jwt.JwtUtils;
import com.pcb.audy.global.redis.RedisProvider;
import com.pcb.audy.global.socket.handler.CustomHandshakeInterceptor;
import com.pcb.audy.global.socket.handler.SocketErrorHandler;
import com.pcb.audy.global.socket.handler.SocketHandler;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.*;
import org.springframework.web.socket.server.HandshakeInterceptor;
Expand All @@ -16,13 +16,14 @@
@RequiredArgsConstructor
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

private final SocketHandler socketHandler;
private final SocketErrorHandler socketErrorHandler;
private final RedisProvider redisProvider;
private final JwtUtils jwtUtils;

@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/course")
.addEndpoint("/course/{courseId}")
.setAllowedOriginPatterns("*")
.addInterceptors(customHandshakeInterceptor())
.withSockJS();
Expand All @@ -35,13 +36,8 @@ public void configureMessageBroker(MessageBrokerRegistry registry) {
registry.setApplicationDestinationPrefixes("/pub"); // 클라이언트에서 보낸 메시지 받기
}

@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(socketHandler);
}

@Bean
public HandshakeInterceptor customHandshakeInterceptor() {
return new CustomHandshakeInterceptor();
return new CustomHandshakeInterceptor(redisProvider, jwtUtils);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import static com.pcb.audy.global.jwt.JwtUtils.ACCESS_TOKEN_NAME;
import static com.pcb.audy.global.jwt.JwtUtils.REFRESH_TOKEN_NAME;
import static com.pcb.audy.global.jwt.JwtUtils.TOKEN_TYPE;

import com.pcb.audy.global.jwt.JwtUtils;
import com.pcb.audy.global.redis.RedisProvider;
import com.pcb.audy.global.validator.TokenValidator;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
Expand All @@ -15,7 +20,11 @@
import org.springframework.web.util.WebUtils;

@Slf4j
@RequiredArgsConstructor
public class CustomHandshakeInterceptor implements HandshakeInterceptor {
private final RedisProvider redisProvider;
private final JwtUtils jwtUtils;
private final String SOCKET_PREFIX = "socket:";

@Override
public boolean beforeHandshake(
Expand All @@ -25,14 +34,10 @@ public boolean beforeHandshake(
Map<String, Object> attributes) {
if (request instanceof ServletServerHttpRequest servletServerRequest) {
HttpServletRequest servletRequest = servletServerRequest.getServletRequest();
Cookie accessCookie = WebUtils.getCookie(servletRequest, ACCESS_TOKEN_NAME);
if (accessCookie != null) {
attributes.put(ACCESS_TOKEN_NAME, accessCookie.getValue());
}

Cookie refreshCookie = WebUtils.getCookie(servletRequest, REFRESH_TOKEN_NAME);
if (refreshCookie != null) {
attributes.put(REFRESH_TOKEN_NAME, refreshCookie.getValue());
String email = getEmail(servletRequest, ACCESS_TOKEN_NAME);
if (email == null) {
email = getEmail(servletRequest, REFRESH_TOKEN_NAME);
TokenValidator.validateEmail(email);
}
}
return true;
Expand All @@ -43,5 +48,23 @@ public void afterHandshake(
ServerHttpRequest request,
ServerHttpResponse response,
WebSocketHandler wsHandler,
Exception exception) {}
Exception exception) {
if (request instanceof ServletServerHttpRequest servletServerRequest) {
HttpServletRequest servletRequest = servletServerRequest.getServletRequest();
if (servletRequest.getRequestURI() != null) {
String email = getEmail(servletRequest, REFRESH_TOKEN_NAME);
String courseId = servletRequest.getRequestURI().replace("/course/", "");
redisProvider.setValues(SOCKET_PREFIX + courseId, email, Integer.MAX_VALUE);
}
}
}

private String getEmail(HttpServletRequest request, String tokenName) {
Cookie cookie = WebUtils.getCookie(request, tokenName);
TokenValidator.validate(cookie);

String token = cookie.getValue().replace(TOKEN_TYPE, "");
log.info(tokenName + " in socket: " + token);
return jwtUtils.getEmail(token);
}
}

This file was deleted.

15 changes: 13 additions & 2 deletions src/main/java/com/pcb/audy/global/validator/TokenValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
import static com.pcb.audy.global.response.ResultCode.INVALID_TOKEN;

import com.pcb.audy.global.exception.GlobalException;
import jakarta.servlet.http.Cookie;

public class TokenValidator {
public static void validate(String cookie) {
if (!isExistCookie(cookie)) {
if (!isValidCookie(cookie)) {
throw new GlobalException(INVALID_TOKEN);
}
}

public static void validate(Cookie cookie) {
if (!isValidCookie(cookie)) {
throw new GlobalException(INVALID_TOKEN);
}
}
Expand All @@ -22,7 +29,11 @@ private static boolean isExistEmail(String email) {
return email != null;
}

private static boolean isExistCookie(String cookie) {
private static boolean isValidCookie(Cookie cookie) {
return cookie != null && cookie.getValue().startsWith(TOKEN_TYPE);
}

private static boolean isValidCookie(String cookie) {
return cookie != null && cookie.startsWith(TOKEN_TYPE);
}
}
24 changes: 24 additions & 0 deletions src/test/java/com/pcb/audy/global/redis/RedisProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import static java.lang.Boolean.TRUE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -14,6 +16,7 @@
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.data.redis.core.ListOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;

Expand All @@ -23,6 +26,7 @@ class RedisProviderTest implements RedisTest {

@Mock private RedisTemplate<String, Object> redisTemplate;
@Mock private ValueOperations<String, Object> valueOperations;
@Mock private ListOperations<String, Object> listOperations;

@Test
@DisplayName("데이터 저장 테스트")
Expand Down Expand Up @@ -82,4 +86,24 @@ class RedisProviderTest implements RedisTest {
verify(redisTemplate).hasKey(any());
assertThat(result).isEqualTo(TRUE);
}

// TODO add save list
@Test
@DisplayName("데이터 list에 저장 테스트")
void 데이터_list_저장() {
// given
when(redisTemplate.opsForList()).thenReturn(listOperations);
when(listOperations.size(any())).thenReturn(0L);
when(listOperations.remove(any(), anyLong(), any())).thenReturn(1L);

// when
redisProvider.setValues(TEST_KEY, TEST_VALUE, TEST_EXPIRE_TIME);

// then
verify(redisTemplate, times(3)).opsForList();
verify(listOperations).size(any());
verify(listOperations).remove(any(), anyLong(), any());
verify(listOperations).rightPush(any(), any());
verify(redisTemplate).expire(any(), anyLong(), any());
}
}

0 comments on commit 8266e78

Please sign in to comment.