import functools import json import time from dataclasses import dataclass from typing import Any, Awaitable, Callable, Protocol from redis import Redis from .helpers import build_cache_key class JobsQueue(Protocol): def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None: ... @dataclass(frozen=True, kw_only=True, slots=True) class RedisCache: """ Container for Redis-backed caching configuration. Attributes: redis: Redis client instance. tasks: Background job queue used to refresh stale entries. prefix: Optional key prefix (defaults to "cache"). ttl_fresh: TTL in seconds for the freshness marker (how long a value is considered fresh). ttl_stale: TTL in seconds for the cached payload. If None, payload does not expire. ttl_stale: TTL in seconds for the stampede protection lock. """ redis: Redis tasks: JobsQueue prefix: str | None = "cache" ttl_fresh: int = 600 ttl_stale: int | None = None ttl_lock: int = 5 @dataclass(frozen=True, kw_only=True, slots=True) class CachedValue[T]: """Wrapper for cached content plus freshness flag.""" content: T fresh: bool def cached(cache: RedisCache): """ Decorator factory to cache function results in Redis with a freshness marker. On miss, uses a short-lock (SETNX) to ensure only one process recomputes. When value is stale, returns the stale value and queues a background recompute. Accepts either: - a RedisCache instance, or - a callable taking the instance (self) and returning RedisCache. If a callable is given, the cache is resolved at call-time using the method's `self`. """ def decorator[T](func: Callable[..., T]) -> Callable[..., CachedValue[T]]: # Keys used in Redis: # - ":val" -> JSON-serialized value # - ":fresh" -> existence means fresh (string "1"), TTL = ttl_fresh # - ":lock" -> short-lived lock to prevent stampede def _redis_val_key(k: str) -> str: return f"{k}:val" def _redis_fresh_key(k: str) -> str: return f"{k}:fresh" def _redis_lock_key(k: str) -> str: return f"{k}:lock" def _serialize(v: Any) -> str: return json.dumps(v, default=str) def _deserialize(s: bytes | str | None) -> Any: if s is None: return None return json.loads(s) def recompute_value(*args: Any, **kwargs: Any) -> None: """ Recompute the function result and store in Redis. This function is intentionally designed to be enqueued into a background job queue. """ # Compute outside of any Redis lock to avoid holding locks during heavy computation. result = func(*args, **kwargs) full_key = build_cache_key(func, cache.prefix, *args, **kwargs) val_key = _redis_val_key(full_key) fresh_key = _redis_fresh_key(full_key) # Store payload (with optional ttl_stale) payload = _serialize(result) if cache.ttl_stale is None: # No expiry for payload cache.redis.set(val_key, payload) else: cache.redis.setex(val_key, cache.ttl_stale, payload) # Create freshness marker with ttl_fresh cache.redis.setex(fresh_key, cache.ttl_fresh, "1") # Ensure lock removed if present (best-effort) try: cache.redis.delete(_redis_lock_key(full_key)) except Exception: # swallow: background job should not crash for Redis delete errors pass @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> CachedValue[T]: """ Attempt to return cached value. If missing entirely, try to acquire a short lock and recompute synchronously; otherwise wait briefly for the recompute or return a fallback (None). If present but stale, return the stale value and enqueue background refresh. """ full_key = build_cache_key(func, cache.prefix, *args, **kwargs) val_key = _redis_val_key(full_key) fresh_key = _redis_fresh_key(full_key) lock_key = _redis_lock_key(full_key) # Try to get payload and freshness marker raw_payload = cache.redis.get(val_key) is_fresh = cache.redis.exists(fresh_key) == 1 if raw_payload is None: # Cache miss. Try to acquire lock to recompute synchronously. if cache.redis.setnx(lock_key, "1") == 1: # Ensure lock TTL so it doesn't persist forever cache.redis.expire(lock_key, cache.ttl_lock) try: # Recompute synchronously and store value = func(*args, **kwargs) payload = _serialize(value) if cache.ttl_stale is not None: cache.redis.setex(val_key, cache.ttl_stale, payload) else: cache.redis.set(val_key, payload) cache.redis.setex(fresh_key, cache.ttl_fresh, "1") return CachedValue(content=value, fresh=True) finally: # Release lock (best-effort) try: cache.redis.delete(lock_key) except Exception: pass else: # Another process is recomputing. Wait briefly for it to finish. # Do not wait indefinitely; poll a couple times with small backoff. wait_deadline = time.time() + cache.ttl_lock while time.time() < wait_deadline: time.sleep(0.05) raw_payload = cache.redis.get(val_key) if raw_payload is not None: break if raw_payload is None: # Still missing after waiting: compute synchronously as fallback to avoid returning None. value = func(*args, **kwargs) payload = _serialize(value) if cache.ttl_stale is None: cache.redis.set(val_key, payload) else: cache.redis.setex(val_key, cache.ttl_stale, payload) cache.redis.setex(fresh_key, cache.ttl_fresh, "1") return CachedValue(content=value, fresh=True) # If we reach here, raw_payload is present (either from the start or after waiting) assert not isinstance(raw_payload, Awaitable) deserialized = _deserialize(raw_payload) # If fresh marker missing => stale if not is_fresh: # Schedule background refresh; do not block caller. cache.tasks.enqueue(recompute_value, *args, **kwargs) return CachedValue(content=deserialized, fresh=is_fresh) return wrapper return decorator