🚧 api: new "cached" decorator implementation
This commit is contained in:
parent
0073e72f9c
commit
5e83b58b32
5 changed files with 279 additions and 121 deletions
|
|
@ -1,113 +0,0 @@
|
|||
import inspect
|
||||
from base64 import urlsafe_b64encode
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from itertools import chain
|
||||
from typing import Iterable, ParamSpec, TypeVar
|
||||
|
||||
from redis import Redis
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True, slots=True)
class Config:
    """Immutable configuration for the Redis-backed cache decorator."""

    # Redis client used as the cache backend.
    redis: Redis
    # Optional key prefix; None disables prefixing entirely.
    prefix: str | None = "cache"
    # NOTE(review): fields below are not yet consumed by redis_cached in this
    # file — presumably seconds of freshness / eviction TTL; confirm intent.
    ttl_fresh: int = 600
    ttl_stale: int | None = None
|
||||
|
||||
|
||||
def qualified_name(callable_obj) -> str:
|
||||
# callable classes/instances
|
||||
if hasattr(callable_obj, "__call__") and not inspect.isroutine(callable_obj):
|
||||
# callable instance: use its class
|
||||
cls = callable_obj.__class__
|
||||
module = getattr(cls, "__module__", None)
|
||||
qual = getattr(cls, "__qualname__", cls.__name__)
|
||||
return f"{module}.{qual}" if module else qual
|
||||
|
||||
# functions, methods, builtins
|
||||
# unwrap descriptors like staticmethod/classmethod
|
||||
if isinstance(callable_obj, staticmethod | classmethod):
|
||||
callable_obj = callable_obj.__func__
|
||||
|
||||
# bound method
|
||||
if inspect.ismethod(callable_obj):
|
||||
func = callable_obj.__func__
|
||||
owner = getattr(func, "__qualname__", func.__name__).rsplit(".", 1)[0]
|
||||
module = getattr(func, "__module__", None)
|
||||
qual = f"{owner}.{func.__name__}"
|
||||
return f"{module}.{qual}" if module else qual
|
||||
|
||||
# regular function or builtin
|
||||
if inspect.isfunction(callable_obj) or inspect.isbuiltin(callable_obj):
|
||||
module = getattr(callable_obj, "__module__", None)
|
||||
qual = getattr(callable_obj, "__qualname__", callable_obj.__name__)
|
||||
return f"{module}.{qual}" if module else qual
|
||||
|
||||
# fallback for other callables (functors, functools.partial, etc.)
|
||||
try:
|
||||
module = getattr(callable_obj, "__module__", None)
|
||||
qual = getattr(callable_obj, "__qualname__", None) or getattr(
|
||||
callable_obj, "__name__", type(callable_obj).__name__
|
||||
)
|
||||
return f"{module}.{qual}" if module else qual
|
||||
|
||||
except Exception:
|
||||
return urlsafe_b64encode(repr(callable_obj).encode("utf-8")).decode("utf-8")
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def args_slice(func: Callable[P, R], *args: Iterable) -> tuple:
|
||||
if hasattr(func, "__call__") and not inspect.isroutine(func):
|
||||
return args[1:]
|
||||
|
||||
if isinstance(func, staticmethod | classmethod):
|
||||
func = func.__func__
|
||||
|
||||
if inspect.ismethod(func):
|
||||
return args[1:]
|
||||
|
||||
return tuple(*args)
|
||||
|
||||
|
||||
def cache_key(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> str:
    """Return a cache key for use with cached methods."""
    # Key layout: "<qualified name>:<repr(arg)>...:<name=repr(value)>..."
    key_parts: list[str] = [qualified_name(func)]

    # positional args (minus any implicit `self`, handled by args_slice)
    key_parts.extend(repr(value) for value in args_slice(func, args))

    # keyword args in deterministic, name-sorted order
    key_parts.extend(
        f"{name}={value!r}" for name, value in sorted(kwargs.items())
    )

    return ":".join(key_parts)
|
||||
|
||||
|
||||
def redis_cached(cfg: Config) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """
    Decorator factory: cache results of the wrapped callable in Redis.

    NOTE(review): the cache key is computed (and prefixed) but never used —
    the pre-/post-hook placeholders below suggest the actual Redis read/write
    logic is still to be implemented; confirm before relying on this.
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        """Wrap `func` so its result can be cached under a deterministic key."""

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # deterministic key derived from the callable and its arguments
            key = cache_key(func, *args, **kwargs)

            if cfg.prefix is not None:
                key = f"{cfg.prefix}:{key}"

            # pre-hook
            result = func(*args, **kwargs)
            # post-hook
            return result

        return wrapper

    return decorator
|
||||
|
|
@ -2,12 +2,14 @@ import logging
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable
|
||||
|
||||
from asyncify import asyncify
|
||||
from cachetools import cachedmethod
|
||||
from fastapi import BackgroundTasks
|
||||
from redis import Redis
|
||||
|
||||
from .helpers import RedisCache, WebDAVclient, davkey
|
||||
from ...redis_cache import JobsQueue, RedisCache, cached
|
||||
from .helpers import WebDAVclient
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -19,11 +21,24 @@ class Settings:
|
|||
password: str = "s3cr3t!"
|
||||
|
||||
|
||||
class FastAPIQueue:
    """JobsQueue adapter backed by FastAPI's ``BackgroundTasks``."""

    # NOTE(review): only annotated here — presumably injected by the caller;
    # confirm the construction site assigns it before enqueue() is used.
    _tasks: BackgroundTasks

    def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None:
        """Schedule ``task(*args, **kwargs)`` to run after the response."""
        # BUG FIX: BackgroundTasks.add_task expects the task's arguments
        # unpacked; passing args=/kwargs= keywords would invoke
        # task(args=(...), kwargs={...}) instead of task(*args, **kwargs).
        self._tasks.add_task(task, *args, **kwargs)
|
||||
|
||||
|
||||
class WebDAV:
|
||||
_webdav_client: WebDAVclient
|
||||
_cache: RedisCache
|
||||
|
||||
def __init__(self, settings: Settings, redis: Redis, ttl_sec: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
redis: Redis,
|
||||
tasks: JobsQueue,
|
||||
ttl_sec: int,
|
||||
) -> None:
|
||||
try:
|
||||
self._webdav_client = WebDAVclient(
|
||||
{
|
||||
|
|
@ -37,10 +52,10 @@ class WebDAV:
|
|||
except AssertionError:
|
||||
raise RuntimeError("WebDAV connection failed!")
|
||||
|
||||
self._cache = RedisCache(cache=redis, ttl=ttl_sec)
|
||||
self._cache = RedisCache(redis=redis, tasks=tasks, ttl_fresh=ttl_sec)
|
||||
|
||||
@asyncify
|
||||
@cachedmethod(cache=lambda self: self._cache, key=davkey("list_files"))
|
||||
@cached(lambda self: self._cache)
|
||||
def _list_files(self, directory: str = "") -> list[str]:
|
||||
"""
|
||||
List files in directory `directory` matching RegEx `regex`
|
||||
|
|
@ -60,7 +75,7 @@ class WebDAV:
|
|||
return [path for path in ls if regex.search(path)]
|
||||
|
||||
@asyncify
|
||||
@cachedmethod(cache=lambda self: self._cache, key=davkey("exists"))
|
||||
@cached(lambda self: self._cache)
|
||||
def exists(self, path: str) -> bool:
|
||||
"""
|
||||
`True` iff there is a WebDAV resource at `path`
|
||||
|
|
@ -70,7 +85,7 @@ class WebDAV:
|
|||
return self._webdav_client.check(path)
|
||||
|
||||
@asyncify
|
||||
@cachedmethod(cache=lambda self: self._cache, key=davkey("read_bytes"))
|
||||
@cached(lambda self: self._cache)
|
||||
def read_bytes(self, path: str) -> bytes:
|
||||
"""
|
||||
Load WebDAV file from `path` as bytes
|
||||
|
|
@ -102,7 +117,7 @@ class WebDAV:
|
|||
|
||||
# invalidate cache entry
|
||||
# begin slice at 0 (there is no "self" argument)
|
||||
del self._cache[davkey("read_bytes", slice(0, None))(path)]
|
||||
# del self._cache[davkey("read_bytes", slice(0, None))(path)]
|
||||
|
||||
async def write_str(self, path: str, content: str, encoding="utf-8") -> None:
|
||||
"""
|
||||
|
|
|
|||
7
api/advent22_api/redis_cache/__init__.py
Normal file
7
api/advent22_api/redis_cache/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from .cached import JobsQueue, RedisCache, cached
|
||||
|
||||
# Public API of the redis_cache package.
__all__ = [
    "JobsQueue",
    "RedisCache",
    "cached",
]
|
||||
186
api/advent22_api/redis_cache/cached.py
Normal file
186
api/advent22_api/redis_cache/cached.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
import functools
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Protocol
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from .helpers import build_cache_key
|
||||
|
||||
|
||||
class JobsQueue(Protocol):
    """Structural interface for anything that can run a task in the background."""

    # Implementations schedule `task(*args, **kwargs)` for later execution.
    def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True, slots=True)
class RedisCache:
    """
    Container for Redis-backed caching configuration.

    Attributes:
        redis: Redis client instance.
        tasks: Background job queue used to refresh stale entries.
        prefix: Optional key prefix (defaults to "cache").
        ttl_fresh: TTL in seconds for the freshness marker (how long a value is considered fresh).
        ttl_stale: TTL in seconds for the cached payload. If None, payload does not expire.
        ttl_lock: TTL in seconds for the stampede protection lock.
    """

    redis: Redis
    tasks: JobsQueue
    prefix: str | None = "cache"
    ttl_fresh: int = 600
    ttl_stale: int | None = None
    ttl_lock: int = 5
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True, slots=True)
class CachedValue[T]:
    """Wrapper for cached content plus freshness flag."""

    # The deserialized cached payload.
    content: T
    # True while the Redis freshness marker for this key still existed.
    fresh: bool
|
||||
|
||||
|
||||
def cached(cache: RedisCache):
|
||||
"""
|
||||
Decorator factory to cache function results in Redis with a freshness marker.
|
||||
On miss, uses a short-lock (SETNX) to ensure only one process recomputes.
|
||||
When value is stale, returns the stale value and queues a background recompute.
|
||||
|
||||
Accepts either:
|
||||
- a RedisCache instance, or
|
||||
- a callable taking the instance (self) and returning RedisCache.
|
||||
|
||||
If a callable is given, the cache is resolved at call-time using the method's `self`.
|
||||
"""
|
||||
|
||||
def decorator[T](func: Callable[..., T]) -> Callable[..., CachedValue[T]]:
|
||||
# Keys used in Redis:
|
||||
# - "<key>:val" -> JSON-serialized value
|
||||
# - "<key>:fresh" -> existence means fresh (string "1"), TTL = ttl_fresh
|
||||
# - "<key>:lock" -> short-lived lock to prevent stampede
|
||||
|
||||
def _redis_val_key(k: str) -> str:
|
||||
return f"{k}:val"
|
||||
|
||||
def _redis_fresh_key(k: str) -> str:
|
||||
return f"{k}:fresh"
|
||||
|
||||
def _redis_lock_key(k: str) -> str:
|
||||
return f"{k}:lock"
|
||||
|
||||
def _serialize(v: Any) -> str:
|
||||
return json.dumps(v, default=str)
|
||||
|
||||
def _deserialize(s: bytes | str | None) -> Any:
|
||||
if s is None:
|
||||
return None
|
||||
|
||||
return json.loads(s)
|
||||
|
||||
def recompute_value(*args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Recompute the function result and store in Redis.
|
||||
This function is intentionally designed to be enqueued into a background job queue.
|
||||
"""
|
||||
# Compute outside of any Redis lock to avoid holding locks during heavy computation.
|
||||
result = func(*args, **kwargs)
|
||||
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
|
||||
val_key = _redis_val_key(full_key)
|
||||
fresh_key = _redis_fresh_key(full_key)
|
||||
|
||||
# Store payload (with optional ttl_stale)
|
||||
payload = _serialize(result)
|
||||
if cache.ttl_stale is None:
|
||||
# No expiry for payload
|
||||
cache.redis.set(val_key, payload)
|
||||
else:
|
||||
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||||
|
||||
# Create freshness marker with ttl_fresh
|
||||
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||||
|
||||
# Ensure lock removed if present (best-effort)
|
||||
try:
|
||||
cache.redis.delete(_redis_lock_key(full_key))
|
||||
except Exception:
|
||||
# swallow: background job should not crash for Redis delete errors
|
||||
pass
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> CachedValue[T]:
|
||||
"""
|
||||
Attempt to return cached value.
|
||||
If missing entirely, try to acquire a short lock and recompute synchronously;
|
||||
otherwise wait briefly for the recompute or return a fallback (None).
|
||||
If present but stale, return the stale value and enqueue background refresh.
|
||||
"""
|
||||
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
|
||||
val_key = _redis_val_key(full_key)
|
||||
fresh_key = _redis_fresh_key(full_key)
|
||||
lock_key = _redis_lock_key(full_key)
|
||||
|
||||
# Try to get payload and freshness marker
|
||||
raw_payload = cache.redis.get(val_key)
|
||||
is_fresh = cache.redis.exists(fresh_key) == 1
|
||||
|
||||
if raw_payload is None:
|
||||
# Cache miss. Try to acquire lock to recompute synchronously.
|
||||
if cache.redis.setnx(lock_key, "1") == 1:
|
||||
# Ensure lock TTL so it doesn't persist forever
|
||||
cache.redis.expire(lock_key, cache.ttl_lock)
|
||||
|
||||
try:
|
||||
# Recompute synchronously and store
|
||||
value = func(*args, **kwargs)
|
||||
payload = _serialize(value)
|
||||
if cache.ttl_stale is not None:
|
||||
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||||
else:
|
||||
cache.redis.set(val_key, payload)
|
||||
|
||||
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||||
return CachedValue(content=value, fresh=True)
|
||||
finally:
|
||||
# Release lock (best-effort)
|
||||
try:
|
||||
cache.redis.delete(lock_key)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Another process is recomputing. Wait briefly for it to finish.
|
||||
# Do not wait indefinitely; poll a couple times with small backoff.
|
||||
wait_deadline = time.time() + cache.ttl_lock
|
||||
while time.time() < wait_deadline:
|
||||
time.sleep(0.05)
|
||||
raw_payload = cache.redis.get(val_key)
|
||||
if raw_payload is not None:
|
||||
break
|
||||
|
||||
if raw_payload is None:
|
||||
# Still missing after waiting: compute synchronously as fallback to avoid returning None.
|
||||
value = func(*args, **kwargs)
|
||||
payload = _serialize(value)
|
||||
if cache.ttl_stale is None:
|
||||
cache.redis.set(val_key, payload)
|
||||
else:
|
||||
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||||
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||||
return CachedValue(content=value, fresh=True)
|
||||
|
||||
# If we reach here, raw_payload is present (either from the start or after waiting)
|
||||
assert not isinstance(raw_payload, Awaitable)
|
||||
deserialized = _deserialize(raw_payload)
|
||||
|
||||
# If fresh marker missing => stale
|
||||
if not is_fresh:
|
||||
# Schedule background refresh; do not block caller.
|
||||
cache.tasks.enqueue(recompute_value, *args, **kwargs)
|
||||
|
||||
return CachedValue(content=deserialized, fresh=is_fresh)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
63
api/advent22_api/redis_cache/helpers.py
Normal file
63
api/advent22_api/redis_cache/helpers.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
import inspect
|
||||
import json
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def stable_repr(val: Any) -> str:
    """Stable JSON representation for cache key components."""
    # sort_keys makes dict ordering deterministic; default=str covers
    # non-JSON-native values (dates, paths, ...) by stringifying them.
    encoded: str = json.dumps(val, default=str, sort_keys=True)
    return encoded
|
||||
|
||||
|
||||
def get_canonical_name(item: Any) -> str:
    """Return canonical module.qualname for functions / callables."""
    # Fall back to the *class* of the object when the object itself
    # does not carry the usual introspection attributes.
    fallback_module = item.__class__.__module__
    fallback_name = getattr(item, "__name__", item.__class__.__name__)

    module = getattr(item, "__module__", fallback_module)
    qualname = getattr(item, "__qualname__", fallback_name)
    return f"{module}.{qualname}"
|
||||
|
||||
|
||||
def build_cache_key(
    func: Callable,
    prefix: str | None,
    *args: Any,
    **kwargs: Any,
) -> str:
    """
    Build a deterministic cache key for func called with args/kwargs.
    For bound methods, skips the first parameter if it's named 'self' or 'cls'.
    """
    signature = inspect.signature(func)
    call = signature.bind_partial(*args, **kwargs)
    call.apply_defaults()

    # bound.arguments preserves parameter order, so the receiver (if any)
    # is always the first entry when it was supplied.
    items = list(call.arguments.items())

    parameters = list(signature.parameters.values())
    if parameters:
        receiver = parameters[0].name
        if receiver in ("self", "cls") and receiver in call.arguments:
            items = items[1:]

    # optional prefix, then canonical name, then name-sorted arguments
    parts: list[str] = [] if prefix is None else [prefix]
    parts.append(get_canonical_name(func))
    for name, value in sorted(items):
        parts.append(f"{name}={stable_repr(value)}")

    return ":".join(parts)
|
||||
Loading…
Reference in a new issue