Skip to content

Commit

Permalink
add eviction_listener callback
Browse files Browse the repository at this point in the history
  • Loading branch information
deliro committed Dec 5, 2024
1 parent 053bcca commit 4e0d03b
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "moka-py"
version = "0.1.11"
version = "0.1.12"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
54 changes: 54 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ async def f(x, y):
await asyncio.sleep(2.0)
return x + y


start = perf_counter()
assert asyncio.run(f(5, 6)) == 11
assert asyncio.run(f(5, 6)) == 11 # got from cache
Expand Down Expand Up @@ -138,6 +139,59 @@ if __name__ == '__main__':

> **_ATTENTION:_** `wait_concurrent` is not yet supported for async functions and will throw `NotImplementedError`
## Eviction listener

moka-py supports an eviction listener that is called whenever an entry is removed
from the cache. The listener must be a function that takes three arguments, `(key, value, cause)`. The arguments
are passed positionally (not as keywords).

There are 4 reasons why a key may be dropped:

1. `"expired"`: The entry's expiration timestamp has passed.
2. `"explicit"`: The entry was manually removed by the user (`.remove()` is called).
3. `"replaced"`: The entry itself was not actually removed, but its value was replaced by the user (`.set()` is
called for an existing entry).
4. `"size"`: The entry was evicted due to size constraints.

```python
from typing import Literal
from moka_py import Moka
from time import sleep


def key_evicted(
k: str,
v: list[int],
cause: Literal["explicit", "size", "expired", "replaced"]
):
print(f"entry {k}:{v} was evicted. {cause=}")


moka: Moka[str, list[int]] = Moka(2, eviction_listener=key_evicted, ttl=0.1)
moka.set("hello", [1, 2, 3])
moka.set("hello", [3, 2, 1])
moka.set("foo", [4])
moka.set("bar", [])
sleep(1)
moka.get("foo")

# will print
# entry hello:[1, 2, 3] was evicted. cause='replaced'
# entry bar:[] was evicted. cause='size'
# entry hello:[3, 2, 1] was evicted. cause='expired'
# entry foo:[4] was evicted. cause='expired'
```

> **_IMPORTANT NOTES_**:
> 1. The listener is not guaranteed to be called at the exact moment an entry becomes evictable. The underlying `moka`
doesn't use any background threads or tasks, so the listener is never called in the background.
> 2. The listener must never raise any kind of `Exception`. If an exception is raised, it may surface from any
moka-py method call, in whichever thread happens to make that call.
> 3. The listener must be fast. Since it's called only when you're interacting with `moka-py` (via `.get()` / `.set()` /
etc.), the listener will slow down these operations. It's a terrible idea to do any kind of IO in the listener. If
you need to, run a `ThreadPoolExecutor` somewhere and call `.submit()` inside the listener, or schedule an async
task via `asyncio.create_task()`.

## Performance

*Measured using MacBook Pro 2021 with Apple M1 Pro processor and 16GiB RAM*
Expand Down
4 changes: 3 additions & 1 deletion moka_py/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import TypeVar, Optional, Generic, Hashable, Union, Callable, Any, overload
from typing import TypeVar, Optional, Generic, Hashable, Union, Callable, Any, overload, Literal


K = TypeVar("K", bound=Hashable)
V = TypeVar("V")
D = TypeVar("D")
Fn = TypeVar("Fn", bound=Callable[..., Any])
Cause = Literal["explicit", "size", "expired", "replaced"]


class Moka(Generic[K, V]):
Expand All @@ -13,6 +14,7 @@ class Moka(Generic[K, V]):
capacity: int,
ttl: Optional[Union[int, float]] = None,
tti: Optional[Union[int, float]] = None,
eviction_listener: Optional[Callable[[K, V, Cause], None]] = None,
): ...

def set(self, key: K, value: V) -> None: ...
Expand Down
93 changes: 73 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;

use moka::notification::RemovalCause;
use moka::sync::Cache;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyString, PyType};

#[derive(Debug)]
enum AnyKey {
enum KeyKind {
/// String keys are the most common. If the string is short enough,
/// we can get faster and more freedom from GIL by copying a string
/// to Rust and hashing it using `ahash` instead of calling
Expand All @@ -20,47 +21,71 @@ enum AnyKey {
ShortStr(String),

/// Other keys (even long Python strings) go this (slower) way
Other(PyObject, isize),
Other { py_hash: isize },
}

/// A cache key: the original Python object paired with the data used to hash and compare it.
#[derive(Debug)]
struct AnyKey {
/// The Python object that was supplied as the key; this is what the eviction listener receives.
obj: PyObject,
/// Selects how the key is hashed and compared: a copied short string or a precomputed Python hash.
kind: KeyKind,
}

impl AnyKey {
const SHORT_STR: usize = 256;

#[inline]
fn new_with_gil(obj: PyObject, py: Python) -> PyResult<Self> {
if let Ok(s) = obj.downcast_bound::<PyString>(py) {
if s.len()? <= Self::SHORT_STR {
return Ok(AnyKey::ShortStr(s.to_string()));
let kind = match obj.downcast_bound::<PyString>(py) {
Ok(s) if s.len()? <= Self::SHORT_STR => KeyKind::ShortStr(s.to_string()),
_ => {
let py_hash = obj.to_object(py).into_bound(py).hash()?;
KeyKind::Other { py_hash }
}
}
let hash = obj.to_object(py).into_bound(py).hash()?;
Ok(AnyKey::Other(obj, hash))
};
Ok(AnyKey { obj, kind })
}
}

impl PartialEq for AnyKey {
#[inline]
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(AnyKey::ShortStr(lhs), AnyKey::ShortStr(rhs)) => lhs == rhs,
(
AnyKey {
kind: KeyKind::ShortStr(lhs),
..
},
AnyKey {
kind: KeyKind::ShortStr(rhs),
..
},
) => lhs == rhs,

// It is expected that `hash` will be stable for an object. Hence, since we already
// know both objects' hashes, we can claim that if their hashes are different,
// the objects aren't equal. Only if the hashes are the same, the objects
// might be equal, and only in that case we raise the GIL to run Python
// rich comparison.
(AnyKey::Other(lhs, lhs_hash), AnyKey::Other(rhs, rhs_hash)) => {
(
AnyKey {
kind: KeyKind::Other { py_hash: lhs_hash },
obj: lhs_obj,
},
AnyKey {
kind: KeyKind::Other { py_hash: rhs_hash },
obj: rhs_obj,
},
) => {
*lhs_hash == *rhs_hash
&& Python::with_gil(|py| {
let lhs = lhs.to_object(py).into_bound(py);
let rhs = rhs.to_object(py).into_bound(py);
let lhs = lhs_obj.to_object(py).into_bound(py);
let rhs = rhs_obj.to_object(py).into_bound(py);
match lhs.rich_compare(rhs, CompareOp::Eq) {
Ok(v) => v.is_truthy().unwrap_or_default(),
Err(_) => false,
}
})
}

_ => false,
}
}
Expand All @@ -70,39 +95,67 @@ impl Eq for AnyKey {}
impl Hash for AnyKey {
#[inline]
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
AnyKey::ShortStr(s) => s.hash(state),
AnyKey::Other(_, hash) => hash.hash(state),
match &self.kind {
KeyKind::ShortStr(s) => s.hash(state),
KeyKind::Other { py_hash } => py_hash.hash(state),
}
}
}

#[inline]
fn cause_to_str(cause: RemovalCause) -> &'static str {
match cause {
RemovalCause::Expired => "expired",
RemovalCause::Explicit => "explicit",
RemovalCause::Replaced => "replaced",
RemovalCause::Size => "size",
}
}

// Python-exposed cache class: a shared handle to a `moka::sync::Cache` keyed by `AnyKey`,
// holding `Arc<PyObject>` values and hashing with `ahash::RandomState`.
// (Plain `//` comment on purpose, so the class docstring seen from Python stays unchanged.)
#[pyclass]
struct Moka(Arc<Cache<AnyKey, Arc<PyObject>, ahash::RandomState>>);

#[pymethods]
impl Moka {
#[new]
#[pyo3(signature = (capacity, ttl=None, tti=None))]
fn new(capacity: u64, ttl: Option<f64>, tti: Option<f64>) -> PyResult<Self> {
#[pyo3(signature = (capacity, ttl=None, tti=None, eviction_listener=None))]
fn new(
capacity: u64,
ttl: Option<f64>,
tti: Option<f64>,
eviction_listener: Option<PyObject>,
) -> PyResult<Self> {
let mut builder = Cache::builder().max_capacity(capacity);

if let Some(ttl) = ttl {
let ttl_micros = (ttl * 1000_000.0) as u64;
let ttl_micros = (ttl * 1_000_000.0) as u64;
if ttl_micros == 0 {
return Err(PyValueError::new_err("ttl must be positive"));
}
builder = builder.time_to_live(Duration::from_micros(ttl_micros));
}

if let Some(tti) = tti {
let tti_micros = (tti * 1000_000.0) as u64;
let tti_micros = (tti * 1_000_000.0) as u64;
if tti_micros == 0 {
return Err(PyValueError::new_err("tti must be positive"));
}
builder = builder.time_to_idle(Duration::from_micros(tti_micros));
}

if let Some(listener) = eviction_listener {
let listen_fn = move |k: Arc<AnyKey>, v: Arc<PyObject>, cause: RemovalCause| {
Python::with_gil(|py| {
let key = k.as_ref().obj.clone_ref(py);
let value = v.as_ref().clone_ref(py);
if let Err(e) = listener.call1(py, (key, value, cause_to_str(cause))) {
e.restore(py)
}
});
};
builder = builder.eviction_listener(Box::new(listen_fn));
}

Ok(Moka(Arc::new(
builder.build_with_hasher(ahash::RandomState::default()),
)))
Expand Down
48 changes: 47 additions & 1 deletion tests/test_moka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from time import monotonic, sleep
import threading
from concurrent.futures import ThreadPoolExecutor
from time import monotonic, sleep, perf_counter

import moka_py

Expand Down Expand Up @@ -114,3 +115,48 @@ def target_2():

t1.join()
t2.join()


def test_eviction_listener():
    """Each of the four removal causes must be reported to the listener at least once."""
    events = []

    def listener(key, value, cause):
        events.append((key, value, cause))

    cache = moka_py.Moka(3, eviction_listener=listener, ttl=0.1)

    # overwrite an existing entry, then remove it by hand
    cache.set("hello", "world")
    cache.set("hello", "REPLACED")
    cache.remove("hello")

    # insert many more entries than the capacity of 3
    cache.set("foo", "bar")
    for idx in range(10):
        cache.set(f"out-of-size-{idx}", 123)

    # outlive the 0.1s ttl before touching the cache again
    sleep(1)
    assert cache.get("foo") is None

    seen_causes = {event[2] for event in events}
    assert seen_causes == {"size", "explicit", "expired", "replaced"}


def test_eviction_listener_io():
    """A listener that hands slow work to a thread pool must not stall cache calls."""
    executor = ThreadPoolExecutor()
    collected = []
    finished = threading.Event()

    def slow_io(key, value, cause):
        sleep(1)
        collected.append((key, value, cause))
        finished.set()

    def listener(key, value, cause):
        # only schedule the slow part; the listener itself returns immediately
        executor.submit(slow_io, key, value, cause)

    cache = moka_py.Moka(3, eviction_listener=listener, ttl=0.1)
    cache.set("hello", "world")
    sleep(0.5)

    started = perf_counter()
    cache.get("hello")
    elapsed = perf_counter() - started
    assert elapsed < 1.0

    # give the worker thread up to 2 seconds to record the eviction
    finished.wait(2.0)
    assert len(collected) == 1

0 comments on commit 4e0d03b

Please sign in to comment.