diff --git a/src/main/java/org/crazycake/shiro/BaseRedisManager.java b/src/main/java/org/crazycake/shiro/BaseRedisManager.java index 1d4747b0a..c66c08dc0 100644 --- a/src/main/java/org/crazycake/shiro/BaseRedisManager.java +++ b/src/main/java/org/crazycake/shiro/BaseRedisManager.java @@ -105,11 +105,20 @@ public void del(byte[] key) { * Return the size of redis db. */ @Override - public Long dbSize() { - Long dbSize = 0L; + public Long dbSize(byte[] pattern) { + long dbSize = 0L; Jedis jedis = getJedis(); try { - dbSize = jedis.dbSize(); + ScanParams params = new ScanParams(); + params.count(count); + params.match(pattern); + byte[] cursor = ScanParams.SCAN_POINTER_START_BINARY; + ScanResult scanResult; + do { + scanResult = jedis.scan(cursor, params); + dbSize++; + cursor = scanResult.getCursorAsBytes(); + } while (scanResult.getStringCursor().compareTo(ScanParams.SCAN_POINTER_START) > 0); } finally { jedis.close(); } diff --git a/src/main/java/org/crazycake/shiro/IRedisManager.java b/src/main/java/org/crazycake/shiro/IRedisManager.java index eb9f7b573..36123a46e 100644 --- a/src/main/java/org/crazycake/shiro/IRedisManager.java +++ b/src/main/java/org/crazycake/shiro/IRedisManager.java @@ -33,7 +33,7 @@ public interface IRedisManager { /** * size */ - Long dbSize(); + Long dbSize(byte[] pattern); /** * keys diff --git a/src/main/java/org/crazycake/shiro/RedisCache.java b/src/main/java/org/crazycake/shiro/RedisCache.java index 03e8f07a1..cecff3b5b 100644 --- a/src/main/java/org/crazycake/shiro/RedisCache.java +++ b/src/main/java/org/crazycake/shiro/RedisCache.java @@ -185,7 +185,12 @@ public void clear() throws CacheException { @Override public int size() { - Long longSize = new Long(redisManager.dbSize()); + Long longSize = 0L; + try { + longSize = new Long(redisManager.dbSize(keySerializer.serialize(this.keyPrefix + "*"))); + } catch (SerializationException e) { + logger.error("get keys error", e); + } return longSize.intValue(); } diff --git a/src/main/java/org/crazycake/shiro/RedisClusterManager.java b/src/main/java/org/crazycake/shiro/RedisClusterManager.java index d3243db4d..2d56f7142 100644 --- a/src/main/java/org/crazycake/shiro/RedisClusterManager.java +++ b/src/main/java/org/crazycake/shiro/RedisClusterManager.java @@ -61,6 +61,7 @@ protected JedisCluster getJedisCluster() { return jedisCluster; } + @Override public byte[] get(byte[] key) { if (key == null) { return null; @@ -68,6 +69,7 @@ public byte[] get(byte[] key) { return getJedisCluster().get(key); } + @Override public byte[] set(byte[] key, byte[] value, int expireTime) { if (key == null) { return null; @@ -79,6 +81,7 @@ public byte[] set(byte[] key, byte[] value, int expireTime) { return value; } + @Override public void del(byte[] key) { if (key == null) { return; @@ -86,23 +89,23 @@ public void del(byte[] key) { getJedisCluster().del(key); } - public Long dbSize() { + @Override + public Long dbSize(byte[] pattern) { Long dbSize = 0L; Map clusterNodes = getJedisCluster().getClusterNodes(); - for (String k : clusterNodes.keySet()) { - JedisPool jp = clusterNodes.get(k); - Jedis connection = jp.getResource(); - try { - dbSize += connection.dbSize(); - } catch (Exception e) { - e.printStackTrace(); - } finally { - connection.close(); + Iterator> nodeIt = clusterNodes.entrySet().iterator(); + while (nodeIt.hasNext()) { + Map.Entry node = nodeIt.next(); + long nodeDbSize = getDbSizeFromClusterNode(node.getValue(), pattern); + if (nodeDbSize == 0L) { + continue; } + dbSize += nodeDbSize; } return dbSize; } + @Override public Set keys(byte[] pattern) { Set keys = new HashSet(); Map clusterNodes = getJedisCluster().getClusterNodes(); @@ -140,6 +143,27 @@ private Set getKeysFromClusterNode(JedisPool jedisPool, byte[] pattern) return keys; } + private long getDbSizeFromClusterNode(JedisPool jedisPool, byte[] pattern) { + long dbSize = 0L; + Jedis jedis = jedisPool.getResource(); + + try { + ScanParams params = new ScanParams(); + params.count(count); + params.match(pattern); + byte[] cursor = ScanParams.SCAN_POINTER_START_BINARY; + ScanResult scanResult; + do { + scanResult = jedis.scan(cursor, params); + dbSize++; + cursor = scanResult.getCursorAsBytes(); + } while (scanResult.getStringCursor().compareTo(ScanParams.SCAN_POINTER_START) > 0); + } finally { + jedis.close(); + } + return dbSize; + } + public int getMaxAttempts() { return maxAttempts; } diff --git a/src/test/java/org/crazycake/shiro/RedisCacheTest.java b/src/test/java/org/crazycake/shiro/RedisCacheTest.java index 947abc4f8..53fc83b83 100644 --- a/src/test/java/org/crazycake/shiro/RedisCacheTest.java +++ b/src/test/java/org/crazycake/shiro/RedisCacheTest.java @@ -82,8 +82,8 @@ public void testRedisCache() { } @Test - public void testSize() { - when(redisManager.dbSize()).thenReturn(2L); + public void testSize() throws SerializationException { + when(redisManager.dbSize(keySerializer.serialize(testPrefix + "*"))).thenReturn(2L); assertThat(redisCache.size(), is(2)); } diff --git a/src/test/java/org/crazycake/shiro/RedisManagerTest.java b/src/test/java/org/crazycake/shiro/RedisManagerTest.java index 1eb86454e..d8d76d24e 100644 --- a/src/test/java/org/crazycake/shiro/RedisManagerTest.java +++ b/src/test/java/org/crazycake/shiro/RedisManagerTest.java @@ -15,7 +15,6 @@ import java.util.Set; import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.*; @@ -85,17 +84,6 @@ public void testDel() throws SerializationException { verify(jedis, times(1)).del(any((new byte[0]).getClass())); } - @Test - public void testDbSize() { - when(jedis.dbSize()).thenReturn(3L); - Long actualDbSize = redisManager.dbSize(); - assertThat(actualDbSize, is(3L)); - - when(jedis.dbSize()).thenReturn(null); - actualDbSize = redisManager.dbSize(); - assertThat(actualDbSize, is(nullValue())); - } - @Test public void testKeys() throws SerializationException { ScanResult scanResult = mock(ScanResult.class);