From 5e83b58b32d8dc134bf82dbe7d3a24f50066c61f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn-Michael=20Miehe?= Date: Sat, 21 Mar 2026 20:55:51 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20api:=20new=20"cached"=20decorato?= =?UTF-8?q?r=20implementation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/advent22_api/core/dav/dec.py | 113 -------------- api/advent22_api/core/dav/webdav.py | 31 +++- api/advent22_api/redis_cache/__init__.py | 7 + api/advent22_api/redis_cache/cached.py | 186 +++++++++++++++++++++++ api/advent22_api/redis_cache/helpers.py | 63 ++++++++ 5 files changed, 279 insertions(+), 121 deletions(-) delete mode 100644 api/advent22_api/core/dav/dec.py create mode 100644 api/advent22_api/redis_cache/__init__.py create mode 100644 api/advent22_api/redis_cache/cached.py create mode 100644 api/advent22_api/redis_cache/helpers.py diff --git a/api/advent22_api/core/dav/dec.py b/api/advent22_api/core/dav/dec.py deleted file mode 100644 index a9e79fa..0000000 --- a/api/advent22_api/core/dav/dec.py +++ /dev/null @@ -1,113 +0,0 @@ -import inspect -from base64 import urlsafe_b64encode -from collections.abc import Callable -from dataclasses import dataclass -from functools import wraps -from itertools import chain -from typing import Iterable, ParamSpec, TypeVar - -from redis import Redis - - -@dataclass(frozen=True, kw_only=True, slots=True) -class Config: - redis: Redis - prefix: str | None = "cache" - ttl_fresh: int = 600 - ttl_stale: int | None = None - - -def qualified_name(callable_obj) -> str: - # callable classes/instances - if hasattr(callable_obj, "__call__") and not inspect.isroutine(callable_obj): - # callable instance: use its class - cls = callable_obj.__class__ - module = getattr(cls, "__module__", None) - qual = getattr(cls, "__qualname__", cls.__name__) - return f"{module}.{qual}" if module else qual - - # functions, methods, builtins - # unwrap descriptors like staticmethod/classmethod - if isinstance(callable_obj, staticmethod | classmethod): - callable_obj = callable_obj.__func__ - - # bound method - if inspect.ismethod(callable_obj): - func = callable_obj.__func__ - owner = getattr(func, "__qualname__", func.__name__).rsplit(".", 1)[0] - module = getattr(func, "__module__", None) - qual = f"{owner}.{func.__name__}" - return f"{module}.{qual}" if module else qual - - # regular function or builtin - if inspect.isfunction(callable_obj) or inspect.isbuiltin(callable_obj): - module = getattr(callable_obj, "__module__", None) - qual = getattr(callable_obj, "__qualname__", callable_obj.__name__) - return f"{module}.{qual}" if module else qual - - # fallback for other callables (functors, functools.partial, etc.) - try: - module = getattr(callable_obj, "__module__", None) - qual = getattr(callable_obj, "__qualname__", None) or getattr( - callable_obj, "__name__", type(callable_obj).__name__ - ) - return f"{module}.{qual}" if module else qual - - except Exception: - return urlsafe_b64encode(repr(callable_obj).encode("utf-8")).decode("utf-8") - - -P = ParamSpec("P") -R = TypeVar("R") - - -def args_slice(func: Callable[P, R], *args: Iterable) -> tuple: - if hasattr(func, "__call__") and not inspect.isroutine(func): - return args[1:] - - if isinstance(func, staticmethod | classmethod): - func = func.__func__ - - if inspect.ismethod(func): - return args[1:] - - return tuple(*args) - - -def cache_key(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> str: - """Return a cache key for use with cached methods.""" - - kwargs_by_key = sorted(kwargs.items(), key=lambda kv: kv[0]) - - parts = chain( - (qualified_name(func),), - # positional args - (repr(arg) for arg in args_slice(func, args)), - # keyword args - (f"{k}={v!r}" for k, v in kwargs_by_key), - ) - - return ":".join(parts) - - -def redis_cached(cfg: Config) -> Callable[[Callable[P, R]], Callable[P, R]]: - """ """ - - def decorator(func: Callable[P, R]) -> Callable[P, R]: - """ """ - - @wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - key = cache_key(func, *args, **kwargs) - - if cfg.prefix is not None: - key = f"{cfg.prefix}:{key}" - - # pre-hook - result = func(*args, **kwargs) - # post-hook - return result - - return wrapper - - return decorator diff --git a/api/advent22_api/core/dav/webdav.py b/api/advent22_api/core/dav/webdav.py index 26125be..c8442c9 100644 --- a/api/advent22_api/core/dav/webdav.py +++ b/api/advent22_api/core/dav/webdav.py @@ -2,12 +2,14 @@ import logging import re from dataclasses import dataclass from io import BytesIO +from typing import Any, Callable from asyncify import asyncify -from cachetools import cachedmethod +from fastapi import BackgroundTasks from redis import Redis -from .helpers import RedisCache, WebDAVclient, davkey +from ...redis_cache import JobsQueue, RedisCache, cached +from .helpers import WebDAVclient _logger = logging.getLogger(__name__) @@ -19,11 +21,24 @@ class Settings: password: str = "s3cr3t!" +class FastAPIQueue: + _tasks: BackgroundTasks + + def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None: + self._tasks.add_task(task, args=args, kwargs=kwargs) + + class WebDAV: _webdav_client: WebDAVclient _cache: RedisCache - def __init__(self, settings: Settings, redis: Redis, ttl_sec: int) -> None: + def __init__( + self, + settings: Settings, + redis: Redis, + tasks: JobsQueue, + ttl_sec: int, + ) -> None: try: self._webdav_client = WebDAVclient( { @@ -37,10 +52,10 @@ class WebDAV: except AssertionError: raise RuntimeError("WebDAV connection failed!") - self._cache = RedisCache(cache=redis, ttl=ttl_sec) + self._cache = RedisCache(redis=redis, tasks=tasks, ttl_fresh=ttl_sec) @asyncify - @cachedmethod(cache=lambda self: self._cache, key=davkey("list_files")) + @cached(lambda self: self._cache) def _list_files(self, directory: str = "") -> list[str]: """ List files in directory `directory` matching RegEx `regex` @@ -60,7 +75,7 @@ class WebDAV: return [path for path in ls if regex.search(path)] @asyncify - @cachedmethod(cache=lambda self: self._cache, key=davkey("exists")) + @cached(lambda self: self._cache) def exists(self, path: str) -> bool: """ `True` iff there is a WebDAV resource at `path` @@ -70,7 +85,7 @@ class WebDAV: return self._webdav_client.check(path) @asyncify - @cachedmethod(cache=lambda self: self._cache, key=davkey("read_bytes")) + @cached(lambda self: self._cache) def read_bytes(self, path: str) -> bytes: """ Load WebDAV file from `path` as bytes @@ -102,7 +117,7 @@ class WebDAV: # invalidate cache entry # begin slice at 0 (there is no "self" argument) - del self._cache[davkey("read_bytes", slice(0, None))(path)] + # del self._cache[davkey("read_bytes", slice(0, None))(path)] async def write_str(self, path: str, content: str, encoding="utf-8") -> None: """ diff --git a/api/advent22_api/redis_cache/__init__.py b/api/advent22_api/redis_cache/__init__.py new file mode 100644 index 0000000..35062a7 --- /dev/null +++ b/api/advent22_api/redis_cache/__init__.py @@ -0,0 +1,7 @@ +from .cached import JobsQueue, RedisCache, cached + +__all__ = [ + "JobsQueue", + "RedisCache", + "cached", +] diff --git a/api/advent22_api/redis_cache/cached.py b/api/advent22_api/redis_cache/cached.py new file mode 100644 index 0000000..ea129bd --- /dev/null +++ b/api/advent22_api/redis_cache/cached.py @@ -0,0 +1,186 @@ +import functools +import json +import time +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Protocol + +from redis import Redis + +from .helpers import build_cache_key + + +class JobsQueue(Protocol): + def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None: ... + + +@dataclass(frozen=True, kw_only=True, slots=True) +class RedisCache: + """ + Container for Redis-backed caching configuration. + + Attributes: + redis: Redis client instance. + tasks: Background job queue used to refresh stale entries. + prefix: Optional key prefix (defaults to "cache"). + ttl_fresh: TTL in seconds for the freshness marker (how long a value is considered fresh). + ttl_stale: TTL in seconds for the cached payload. If None, payload does not expire. + ttl_stale: TTL in seconds for the stampede protection lock. + """ + + redis: Redis + tasks: JobsQueue + prefix: str | None = "cache" + ttl_fresh: int = 600 + ttl_stale: int | None = None + ttl_lock: int = 5 + + +@dataclass(frozen=True, kw_only=True, slots=True) +class CachedValue[T]: + """Wrapper for cached content plus freshness flag.""" + + content: T + fresh: bool + + +def cached(cache: RedisCache): + """ + Decorator factory to cache function results in Redis with a freshness marker. + On miss, uses a short-lock (SETNX) to ensure only one process recomputes. + When value is stale, returns the stale value and queues a background recompute. + + Accepts either: + - a RedisCache instance, or + - a callable taking the instance (self) and returning RedisCache. + + If a callable is given, the cache is resolved at call-time using the method's `self`. + """ + + def decorator[T](func: Callable[..., T]) -> Callable[..., CachedValue[T]]: + # Keys used in Redis: + # - ":val" -> JSON-serialized value + # - ":fresh" -> existence means fresh (string "1"), TTL = ttl_fresh + # - ":lock" -> short-lived lock to prevent stampede + + def _redis_val_key(k: str) -> str: + return f"{k}:val" + + def _redis_fresh_key(k: str) -> str: + return f"{k}:fresh" + + def _redis_lock_key(k: str) -> str: + return f"{k}:lock" + + def _serialize(v: Any) -> str: + return json.dumps(v, default=str) + + def _deserialize(s: bytes | str | None) -> Any: + if s is None: + return None + + return json.loads(s) + + def recompute_value(*args: Any, **kwargs: Any) -> None: + """ + Recompute the function result and store in Redis. + This function is intentionally designed to be enqueued into a background job queue. + """ + # Compute outside of any Redis lock to avoid holding locks during heavy computation. + result = func(*args, **kwargs) + full_key = build_cache_key(func, cache.prefix, *args, **kwargs) + val_key = _redis_val_key(full_key) + fresh_key = _redis_fresh_key(full_key) + + # Store payload (with optional ttl_stale) + payload = _serialize(result) + if cache.ttl_stale is None: + # No expiry for payload + cache.redis.set(val_key, payload) + else: + cache.redis.setex(val_key, cache.ttl_stale, payload) + + # Create freshness marker with ttl_fresh + cache.redis.setex(fresh_key, cache.ttl_fresh, "1") + + # Ensure lock removed if present (best-effort) + try: + cache.redis.delete(_redis_lock_key(full_key)) + except Exception: + # swallow: background job should not crash for Redis delete errors + pass + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> CachedValue[T]: + """ + Attempt to return cached value. + If missing entirely, try to acquire a short lock and recompute synchronously; + otherwise wait briefly for the recompute or return a fallback (None). + If present but stale, return the stale value and enqueue background refresh. + """ + full_key = build_cache_key(func, cache.prefix, *args, **kwargs) + val_key = _redis_val_key(full_key) + fresh_key = _redis_fresh_key(full_key) + lock_key = _redis_lock_key(full_key) + + # Try to get payload and freshness marker + raw_payload = cache.redis.get(val_key) + is_fresh = cache.redis.exists(fresh_key) == 1 + + if raw_payload is None: + # Cache miss. Try to acquire lock to recompute synchronously. + if cache.redis.setnx(lock_key, "1") == 1: + # Ensure lock TTL so it doesn't persist forever + cache.redis.expire(lock_key, cache.ttl_lock) + + try: + # Recompute synchronously and store + value = func(*args, **kwargs) + payload = _serialize(value) + if cache.ttl_stale is not None: + cache.redis.setex(val_key, cache.ttl_stale, payload) + else: + cache.redis.set(val_key, payload) + + cache.redis.setex(fresh_key, cache.ttl_fresh, "1") + return CachedValue(content=value, fresh=True) + finally: + # Release lock (best-effort) + try: + cache.redis.delete(lock_key) + except Exception: + pass + else: + # Another process is recomputing. Wait briefly for it to finish. + # Do not wait indefinitely; poll a couple times with small backoff. + wait_deadline = time.time() + cache.ttl_lock + while time.time() < wait_deadline: + time.sleep(0.05) + raw_payload = cache.redis.get(val_key) + if raw_payload is not None: + break + + if raw_payload is None: + # Still missing after waiting: compute synchronously as fallback to avoid returning None. + value = func(*args, **kwargs) + payload = _serialize(value) + if cache.ttl_stale is None: + cache.redis.set(val_key, payload) + else: + cache.redis.setex(val_key, cache.ttl_stale, payload) + cache.redis.setex(fresh_key, cache.ttl_fresh, "1") + return CachedValue(content=value, fresh=True) + + # If we reach here, raw_payload is present (either from the start or after waiting) + assert not isinstance(raw_payload, Awaitable) + deserialized = _deserialize(raw_payload) + + # If fresh marker missing => stale + if not is_fresh: + # Schedule background refresh; do not block caller. + cache.tasks.enqueue(recompute_value, *args, **kwargs) + + return CachedValue(content=deserialized, fresh=is_fresh) + + return wrapper + + return decorator diff --git a/api/advent22_api/redis_cache/helpers.py b/api/advent22_api/redis_cache/helpers.py new file mode 100644 index 0000000..621490b --- /dev/null +++ b/api/advent22_api/redis_cache/helpers.py @@ -0,0 +1,63 @@ +import inspect +import json +from typing import Any, Callable + + +def stable_repr(val: Any) -> str: + """Stable JSON representation for cache key components.""" + return json.dumps(val, sort_keys=True, default=str) + + +def get_canonical_name(item: Any) -> str: + """Return canonical module.qualname for functions / callables.""" + module = getattr( + item, + "__module__", + item.__class__.__module__, + ) + qualname = getattr( + item, + "__qualname__", + getattr(item, "__name__", item.__class__.__name__), + ) + return f"{module}.{qualname}" + + +def build_cache_key( + func: Callable, + prefix: str | None, + *args: Any, + **kwargs: Any, +) -> str: + """ + Build a deterministic cache key for func called with args/kwargs. + For bound methods, skips the first parameter if it's named 'self' or 'cls'. + """ + sig = inspect.signature(func) + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + + params = list(sig.parameters.values()) + arguments = list(bound.arguments.items()) + + # Detect methods: if first parameter name is 'self' or 'cls' and it's provided in bound args, + # skip it when building the key. + if params: + first_name = params[0].name + if first_name in ("self", "cls") and first_name in bound.arguments: + arguments = arguments[1:] + + arguments_fmt = [ + f"{name}={stable_repr(val)}" + for name, val in sorted(arguments, key=lambda kv: kv[0]) + ] + + key_parts = [ + get_canonical_name(func), + *arguments_fmt, + ] + + if prefix is not None: + key_parts = [prefix] + key_parts + + return ":".join(key_parts)