Skip to content

Commit

Permalink
Fixe Redis Cache driver Issues
Browse files Browse the repository at this point in the history
Refactored Redis Cache Driver
- Fixed internal cache not always loaded on first cace access
- added ability to define “timeout” for for item expiry
- - defaults to 1 month.  Previous hardcoded value was 10 years
- correctly store and unpack int types
- fix internal cache not removed if store was flushed
- fixed an issue where am imternal cache key would not be updated if it existed, but would be updated in the Redis store
  • Loading branch information
circulon committed Jul 24, 2024
1 parent 22c5512 commit 9d6adb6
Showing 1 changed file with 36 additions and 32 deletions.
68 changes: 36 additions & 32 deletions src/masonite/cache/drivers/RedisDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,51 @@ class RedisDriver:
def __init__(self, application):
self.application = application
self.connection = None
self._internal_cache: "dict|None" = None
self.options = None
self._internal_cache: dict = None

def set_options(self, options: dict) -> "RedisDriver":
self.options = options
return self

def get_connection(self) -> "Redis":
if self.connection:
return self.connection

try:
from redis import Redis
except ImportError:
raise ModuleNotFoundError(
"Could not find the 'redis' library. Run 'pip install redis' to fix this."
)

if not self.connection:
self.connection = Redis(
**self.options.get("options", {}),
host=self.options.get("host"),
port=self.options.get("port"),
password=self.options.get("password"),
decode_responses=True,
)

# populate the internal cache the first time
# the connection is established
if self._internal_cache is None and self.connection:
self._load_from_store(self.connection)
self.connection = Redis(
**self.options.get("options", {}),
host=self.options.get("host"),
port=self.options.get("port"),
password=self.options.get("password"),
decode_responses=True,
)

return self.connection

def _load_from_store(self, connection: "Redis" = None) -> None:
def _load_from_store(self) -> None:
"""
copy all the "cache" key value pairs for faster access
"""
if not connection:
if self._internal_cache is not None:
return

if self._internal_cache is None:
self._internal_cache = {}
self._internal_cache = {}

cursor = "0"
prefix = self.get_cache_namespace()
while cursor != 0:
cursor, keys = connection.scan(
cursor, keys = self.get_connection().scan(
cursor=cursor, match=prefix + "*", count=100000
)
if keys:
values = connection.mget(*keys)
values = self.get_connection().mget(*keys)
store_data = dict(zip(keys, values))
for key, value in store_data.items():
key = key.replace(prefix, "")
Expand All @@ -72,15 +69,15 @@ def get_cache_namespace(self) -> str:
return f"{namespace}cache:"

def add(self, key: str, value: Any = None) -> Any:
if not value:
if value is None:
return None

self.put(key, value)
return value

def get(self, key: str, default: Any = None, **options) -> Any:
self._load_from_store()
if default and not self.has(key):
self.put(key, default, **options)
return default

return self._internal_cache.get(key)
Expand All @@ -89,20 +86,22 @@ def put(self, key: str, value: Any = None, seconds: int = None, **options) -> An
if not key or value is None:
return None

time = self.get_expiration_time(seconds)
time = seconds or self.get_default_timeout()

store_value = value
if isinstance(value, (dict, list, tuple)):
store_value = json.dumps(value)
elif isinstance(value, int):
store_value = str(value)

self._load_from_store()
self.get_connection().set(
f"{self.get_cache_namespace()}{key}", store_value, ex=time
)

if not self.has(key):
self._internal_cache.update({key: value})
self._internal_cache.update({key: value})

def has(self, key: str) -> bool:
self._load_from_store()
return key in self._internal_cache

def increment(self, key: str, amount: int = 1) -> int:
Expand All @@ -126,22 +125,27 @@ def remember(self, key: str, callable):
return self.get(key)

def forget(self, key: str) -> None:
if not self.has(key):
return
self.get_connection().delete(f"{self.get_cache_namespace()}{key}")
self._internal_cache.pop(key)

def flush(self) -> None:
return self.get_connection().flushall()
flushed = self.get_connection().flushall()
if flushed:
self._internal_cache = None

def get_expiration_time(self, seconds: int) -> int:
if seconds is None:
seconds = 31557600 * 10
return flushed

return seconds
def get_default_timeout(self) -> int:
# if unset default timeout of cache vars is 1 month
return int(self.options.get("timeout", 60 * 60 * 24 * 30))

def unpack_value(self, value: Any) -> Any:
value = str(value)
if value.isdigit():
return str(value)
return int(value)

try:
return json.loads(value)
except json.decoder.JSONDecodeError:
Expand Down

0 comments on commit 9d6adb6

Please sign in to comment.