Skip to content

Commit

Permalink
add default argument to Moka.get
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Kitaev committed Dec 2, 2024
1 parent 5da2d6f commit 4f8c8f2
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "moka-py"
version = "0.1.9"
version = "0.1.10"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
15 changes: 9 additions & 6 deletions moka_py/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import asyncio
from functools import wraps, _make_key
from .moka_py import Moka
from .moka_py import Moka, get_version


__all__ = ["Moka", "cached"]
__all__ = ["Moka", "cached", "VERSION"]

VERSION = get_version()


def cached(maxsize=128, typed=False, *, ttl=None, tti=None, wait_concurrent=False):
cache = Moka(maxsize, ttl, tti)
empty = object()

def dec(fn):
if asyncio.iscoroutinefunction(fn):
Expand All @@ -17,8 +20,8 @@ def dec(fn):
@wraps(fn)
async def inner(*args, **kwargs):
key = _make_key(args, kwargs, typed)
maybe_value = cache.get(key)
if maybe_value is not None:
maybe_value = cache.get(key, empty)
if maybe_value is not empty:
return maybe_value
value = await fn(*args, **kwargs)
cache.set(key, value)
Expand All @@ -30,8 +33,8 @@ def inner(*args, **kwargs):
if wait_concurrent:
return cache.get_with(key, lambda: fn(*args, **kwargs))
else:
maybe_value = cache.get(key)
if maybe_value is not None:
maybe_value = cache.get(key, empty)
if maybe_value is not empty:
return maybe_value
value = fn(*args, **kwargs)
cache.set(key, value)
Expand Down
9 changes: 7 additions & 2 deletions moka_py/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TypeVar, Optional, Generic, Hashable, Union, Callable, Any
from typing import TypeVar, Optional, Generic, Hashable, Union, Callable, Any, overload


K = TypeVar("K", bound=Hashable)
V = TypeVar("V")
D = TypeVar("D")
Fn = TypeVar("Fn", bound=Callable[..., Any])


Expand All @@ -16,7 +17,11 @@ class Moka(Generic[K, V]):

def set(self, key: K, value: V) -> None: ...

def get(self, key: K) -> Optional[V]: ...
@overload
def get(self, key: K, default: D) -> Union[V, D]: ...

@overload
def get(self, key: K, default: Optional[D] = None) -> Optional[Union[V, D]]: ...

def get_with(self, key: K, initializer: Callable[[], V]) -> V:
"""
Expand Down
19 changes: 17 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,19 @@ impl Moka {
Ok(())
}

fn get(&self, py: Python, key: PyObject) -> PyResult<Option<PyObject>> {
#[pyo3(signature = (key, default=None))]
fn get(
&self,
py: Python,
key: PyObject,
default: Option<PyObject>,
) -> PyResult<Option<PyObject>> {
let hashable_key = AnyKey::new_with_gil(key, py)?;
let value = py.allow_threads(|| self.0.get(&hashable_key));
Ok(value.map(|obj| obj.clone_ref(py)))
Ok(match value.map(|obj| obj.clone_ref(py)) {
None => default.map(|v| v.clone_ref(py)),
Some(v) => Some(v),
})
}

fn get_with(&self, py: Python, key: PyObject, initializer: PyObject) -> PyResult<PyObject> {
Expand Down Expand Up @@ -155,8 +164,14 @@ impl Moka {
}
}

#[pyfunction]
fn get_version() -> &'static str {
env!("CARGO_PKG_VERSION")
}

#[pymodule]
fn moka_py(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Moka>()?;
m.add_function(wrap_pyfunction!(get_version, m)?)?;
Ok(())
}
8 changes: 8 additions & 0 deletions tests/test_moka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from time import monotonic, sleep
import threading

import moka_py


Expand Down Expand Up @@ -78,3 +79,10 @@ def target():
t2.join()

assert len(calls) == 1


def test_default() -> None:
moka = moka_py.Moka(128)
moka.set("hello", [1, 2, 3])
assert moka.get("world") is None
assert moka.get("world", "foo") == "foo"

0 comments on commit 4f8c8f2

Please sign in to comment.