Skip to content

Commit

Permalink
- Check expired sessionInThread in RedisSessionDAO
Browse files Browse the repository at this point in the history
- Add testing sessionInThread in integrationTest
  • Loading branch information
alexy committed Sep 7, 2020
1 parent ff7e1e5 commit e36d35e
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 28 deletions.
2 changes: 0 additions & 2 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ You use either of the following 2 ways to include `shiro-redis` into your projec

> **Note:**\
> Do not use version < 3.1.0\
> **注意**\
> 请不要使用3.1.0以下版本
## shiro-core/jedis Version Comparison Charts

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.crazycake.shiro.common.SessionInMemory;
import org.crazycake.shiro.exception.SerializationException;
import org.crazycake.shiro.integration.fixture.model.FakeSession;
import org.crazycake.shiro.serializer.ObjectSerializer;
import org.crazycake.shiro.serializer.StringSerializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
Expand All @@ -16,6 +17,8 @@
import java.util.Map;

import static org.crazycake.shiro.integration.fixture.TestFixture.*;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.assertThat;

/**
* RedisSessionDAO integration test was put under org.crazycake.shiro
Expand All @@ -29,6 +32,9 @@ public class RedisSessionDAOIntegrationTest {
private FakeSession emptySession;
private String name1;
private String prefix;
private StringSerializer keySerializer = new StringSerializer();
private ObjectSerializer valueSerializer = new ObjectSerializer();

private void blast() {
blastRedis();
}
Expand Down Expand Up @@ -136,4 +142,28 @@ public void testRemoveExpiredSessionInMemory() throws InterruptedException, Seri
Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) redisSessionDAO.getSessionsInThread().get();
assertEquals(sessionMap.size(), 1);
}

@Test
public void testTurnOffSessionInMemoryEnabled() throws InterruptedException, SerializationException {
redisSessionDAO.setSessionInMemoryTimeout(2000L);
session1.setCompany("apple");
redisSessionDAO.doCreate(session1);
// Load session into SessionInThread
redisSessionDAO.doReadSession(session1.getId());
// Directly update session in Redis
session1.setCompany("google");
RedisManager redisManager = scaffoldStandaloneRedisManager();
String sessionRedisKey = prefix + session1.getId();
redisManager.set(keySerializer.serialize(sessionRedisKey), valueSerializer.serialize(session1), 10);
// Try to read session again
Thread.sleep(500);
FakeSession sessionFromThreadLocal = (FakeSession)redisSessionDAO.doReadSession(session1.getId());
// The company should be the old value
assertThat(sessionFromThreadLocal.getCompany(), is("apple"));
// Turn off sessionInMemoryEnabled
redisSessionDAO.setSessionInMemoryEnabled(false);
// Try to read session again. It should get the version in Redis
FakeSession sessionFromRedis = (FakeSession)redisSessionDAO.doReadSession(session1.getId());
assertThat(sessionFromRedis.getCompany(), is("google"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ public void testPutString() {
public void testSize() throws InterruptedException {
doPutAuth(redisCache, user1);
doPutAuth(redisCache, user2);
Thread.sleep(800);
assertEquals(redisCache.size(), 2);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package org.crazycake.shiro.integration;

import com.github.javafaker.Faker;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.crazycake.shiro.RedisCacheManager;
import org.crazycake.shiro.RedisManager;
import org.crazycake.shiro.RedisSessionDAO;
import org.crazycake.shiro.exception.SerializationException;
import org.crazycake.shiro.integration.fixture.model.UserInfo;
import org.crazycake.shiro.serializer.ObjectSerializer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
public class FakeSession extends SimpleSession implements Serializable, Session{
private Integer id;
private String name;
private String company;

public FakeSession() {}

Expand All @@ -35,6 +36,14 @@ public void setName(String name) {
this.name = name;
}

public String getCompany() {
return company;
}

public void setCompany(String company) {
this.company = company;
}

@Override
public Date getStartTimestamp() {
return null;
Expand Down
73 changes: 53 additions & 20 deletions src/main/java/org/crazycake/shiro/RedisSessionDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ public class RedisSessionDAO extends AbstractSessionDAO {
*/
@Override
public void update(Session session) throws UnknownSessionException {
if (this.sessionInMemoryEnabled) {
this.removeExpiredSessionInMemory();
}
this.saveSession(session);
if (this.sessionInMemoryEnabled) {
this.setSessionToThreadLocal(session.getId(), session);
Expand Down Expand Up @@ -117,10 +120,16 @@ private void saveSession(Session session) throws UnknownSessionException {
*/
@Override
public void delete(Session session) {
if (this.sessionInMemoryEnabled) {
this.removeExpiredSessionInMemory();
}
if (session == null || session.getId() == null) {
logger.error("session or session id is null");
return;
}
if (this.sessionInMemoryEnabled) {
this.delSessionFromThreadLocal(session.getId());
}
try {
redisManager.del(keySerializer.serialize(getRedisSessionKey(session.getId())));
} catch (SerializationException e) {
Expand All @@ -134,6 +143,9 @@ public void delete(Session session) {
*/
@Override
public Collection<Session> getActiveSessions() {
if (this.sessionInMemoryEnabled) {
this.removeExpiredSessionInMemory();
}
Set<Session> sessions = new HashSet<Session>();
try {
Set<byte[]> keys = redisManager.keys(keySerializer.serialize(this.keyPrefix + "*"));
Expand All @@ -151,11 +163,14 @@ public Collection<Session> getActiveSessions() {

@Override
protected Serializable doCreate(Session session) {
if (this.sessionInMemoryEnabled) {
this.removeExpiredSessionInMemory();
}
if (session == null) {
logger.error("session is null");
throw new UnknownSessionException("session is null");
}
Serializable sessionId = this.generateSessionId(session);
Serializable sessionId = this.generateSessionId(session);
this.assignSessionId(session, sessionId);
this.saveSession(session);
return sessionId;
Expand All @@ -168,47 +183,67 @@ protected Serializable doCreate(Session session) {
*/
@Override
protected Session doReadSession(Serializable sessionId) {
if (this.sessionInMemoryEnabled) {
this.removeExpiredSessionInMemory();
}
if (sessionId == null) {
logger.warn("session id is null");
return null;
}

if (this.sessionInMemoryEnabled) {
Session session = getSessionFromThreadLocal(sessionId);
if (session != null) {
return session;
}
}

Session session = null;
logger.debug("read session from redis");
try {
session = (Session) valueSerializer.deserialize(redisManager.get(keySerializer.serialize(getRedisSessionKey(sessionId))));
String sessionRedisKey = getRedisSessionKey(sessionId);
logger.debug("read session: " + sessionRedisKey + " from Redis");
session = (Session) valueSerializer.deserialize(redisManager.get(keySerializer.serialize(sessionRedisKey)));
if (this.sessionInMemoryEnabled) {
setSessionToThreadLocal(sessionId, session);
}
} catch (SerializationException e) {
logger.error("read session error. sessionId=" + sessionId);
logger.error("read session error. sessionId: " + sessionId);
}
return session;
}

private void setSessionToThreadLocal(Serializable sessionId, Session s) {
private void setSessionToThreadLocal(Serializable sessionId, Session session) {
this.initSessionsInThread();
Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
if (sessionMap == null) {
sessionMap = new HashMap<Serializable, SessionInMemory>();
sessionsInThread.set(sessionMap);
}
sessionMap.put(sessionId, this.createSessionInMemory(session));
}

removeExpiredSessionInMemory(sessionMap);
private void delSessionFromThreadLocal(Serializable sessionId) {
Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
if (sessionMap == null) {
return;
}
sessionMap.remove(sessionId);
}

private SessionInMemory createSessionInMemory(Session session) {
SessionInMemory sessionInMemory = new SessionInMemory();
sessionInMemory.setCreateTime(new Date());
sessionInMemory.setSession(s);
sessionMap.put(sessionId, sessionInMemory);
sessionInMemory.setSession(session);
return sessionInMemory;
}

private void initSessionsInThread() {
Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
if (sessionMap == null) {
sessionMap = new HashMap<Serializable, SessionInMemory>();
sessionsInThread.set(sessionMap);
}
}

private void removeExpiredSessionInMemory(Map<Serializable, SessionInMemory> sessionMap) {
private void removeExpiredSessionInMemory() {
Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
if (sessionMap == null) {
return;
}
Iterator<Serializable> it = sessionMap.keySet().iterator();
while (it.hasNext()) {
Serializable sessionId = it.next();
Expand All @@ -222,6 +257,9 @@ private void removeExpiredSessionInMemory(Map<Serializable, SessionInMemory> ses
it.remove();
}
}
if (sessionMap.size() == 0) {
sessionsInThread.remove();
}
}

private Session getSessionFromThreadLocal(Serializable sessionId) {
Expand All @@ -234,11 +272,6 @@ private Session getSessionFromThreadLocal(Serializable sessionId) {
if (sessionInMemory == null) {
return null;
}
long liveTime = getSessionInMemoryLiveTime(sessionInMemory);
if (liveTime > sessionInMemoryTimeout) {
sessionMap.remove(sessionId);
return null;
}

logger.debug("read session from memory");
return sessionInMemory.getSession();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public byte[] get(byte[] key) {
* set
* @param key key
* @param value value
* @param expireTime expire time
* @param expireTime expire time in second
* @return value
*/
@Override
Expand Down

0 comments on commit e36d35e

Please sign in to comment.