advent22/api/advent22_api/redis_cache/cached.py

187 lines
7.2 KiB
Python
Raw Normal View History

import functools
import json
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Protocol
from redis import Redis
from .helpers import build_cache_key
class JobsQueue(Protocol):
    """Structural interface for a background job queue.

    Used by the cache to enqueue stale-entry refresh tasks; any object with a
    compatible ``enqueue`` method satisfies this protocol.
    """

    def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None:
        """Schedule *task* to be executed in the background with the given arguments."""
        ...
@dataclass(frozen=True, kw_only=True, slots=True)
class RedisCache:
    """
    Container for Redis-backed caching configuration.

    Attributes:
        redis: Redis client instance.
        tasks: Background job queue used to refresh stale entries.
        prefix: Optional key prefix (defaults to "cache").
        ttl_fresh: TTL in seconds for the freshness marker (how long a value is considered fresh).
        ttl_stale: TTL in seconds for the cached payload. If None, payload does not expire.
        ttl_lock: TTL in seconds for the stampede protection lock.
    """

    redis: Redis
    tasks: JobsQueue
    prefix: str | None = "cache"
    ttl_fresh: int = 600
    ttl_stale: int | None = None
    ttl_lock: int = 5
@dataclass(frozen=True, kw_only=True, slots=True)
class CachedValue[T]:
    """Wrapper for cached content plus freshness flag."""

    # The cached (deserialized or freshly computed) value.
    content: T
    # True when the value's freshness marker was present (or the value was
    # just computed); False means a stale value was served.
    fresh: bool
def cached(cache: RedisCache):
"""
Decorator factory to cache function results in Redis with a freshness marker.
On miss, uses a short-lock (SETNX) to ensure only one process recomputes.
When value is stale, returns the stale value and queues a background recompute.
Accepts either:
- a RedisCache instance, or
- a callable taking the instance (self) and returning RedisCache.
If a callable is given, the cache is resolved at call-time using the method's `self`.
"""
def decorator[T](func: Callable[..., T]) -> Callable[..., CachedValue[T]]:
# Keys used in Redis:
# - "<key>:val" -> JSON-serialized value
# - "<key>:fresh" -> existence means fresh (string "1"), TTL = ttl_fresh
# - "<key>:lock" -> short-lived lock to prevent stampede
def _redis_val_key(k: str) -> str:
return f"{k}:val"
def _redis_fresh_key(k: str) -> str:
return f"{k}:fresh"
def _redis_lock_key(k: str) -> str:
return f"{k}:lock"
def _serialize(v: Any) -> str:
return json.dumps(v, default=str)
def _deserialize(s: bytes | str | None) -> Any:
if s is None:
return None
return json.loads(s)
def recompute_value(*args: Any, **kwargs: Any) -> None:
"""
Recompute the function result and store in Redis.
This function is intentionally designed to be enqueued into a background job queue.
"""
# Compute outside of any Redis lock to avoid holding locks during heavy computation.
result = func(*args, **kwargs)
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
val_key = _redis_val_key(full_key)
fresh_key = _redis_fresh_key(full_key)
# Store payload (with optional ttl_stale)
payload = _serialize(result)
if cache.ttl_stale is None:
# No expiry for payload
cache.redis.set(val_key, payload)
else:
cache.redis.setex(val_key, cache.ttl_stale, payload)
# Create freshness marker with ttl_fresh
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
# Ensure lock removed if present (best-effort)
try:
cache.redis.delete(_redis_lock_key(full_key))
except Exception:
# swallow: background job should not crash for Redis delete errors
pass
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> CachedValue[T]:
"""
Attempt to return cached value.
If missing entirely, try to acquire a short lock and recompute synchronously;
otherwise wait briefly for the recompute or return a fallback (None).
If present but stale, return the stale value and enqueue background refresh.
"""
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
val_key = _redis_val_key(full_key)
fresh_key = _redis_fresh_key(full_key)
lock_key = _redis_lock_key(full_key)
# Try to get payload and freshness marker
raw_payload = cache.redis.get(val_key)
is_fresh = cache.redis.exists(fresh_key) == 1
if raw_payload is None:
# Cache miss. Try to acquire lock to recompute synchronously.
if cache.redis.setnx(lock_key, "1") == 1:
# Ensure lock TTL so it doesn't persist forever
cache.redis.expire(lock_key, cache.ttl_lock)
try:
# Recompute synchronously and store
value = func(*args, **kwargs)
payload = _serialize(value)
if cache.ttl_stale is not None:
cache.redis.setex(val_key, cache.ttl_stale, payload)
else:
cache.redis.set(val_key, payload)
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
return CachedValue(content=value, fresh=True)
finally:
# Release lock (best-effort)
try:
cache.redis.delete(lock_key)
except Exception:
pass
else:
# Another process is recomputing. Wait briefly for it to finish.
# Do not wait indefinitely; poll a couple times with small backoff.
wait_deadline = time.time() + cache.ttl_lock
while time.time() < wait_deadline:
time.sleep(0.05)
raw_payload = cache.redis.get(val_key)
if raw_payload is not None:
break
if raw_payload is None:
# Still missing after waiting: compute synchronously as fallback to avoid returning None.
value = func(*args, **kwargs)
payload = _serialize(value)
if cache.ttl_stale is None:
cache.redis.set(val_key, payload)
else:
cache.redis.setex(val_key, cache.ttl_stale, payload)
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
return CachedValue(content=value, fresh=True)
# If we reach here, raw_payload is present (either from the start or after waiting)
assert not isinstance(raw_payload, Awaitable)
deserialized = _deserialize(raw_payload)
# If fresh marker missing => stale
if not is_fresh:
# Schedule background refresh; do not block caller.
cache.tasks.enqueue(recompute_value, *args, **kwargs)
return CachedValue(content=deserialized, fresh=is_fresh)
return wrapper
return decorator