187 lines
7.2 KiB
Python
187 lines
7.2 KiB
Python
|
|
import functools
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Any, Awaitable, Callable, Protocol
|
||
|
|
|
||
|
|
from redis import Redis
|
||
|
|
|
||
|
|
from .helpers import build_cache_key
|
||
|
|
|
||
|
|
|
||
|
|
class JobsQueue(Protocol):
    """Structural (duck-typed) interface for a background job queue."""

    def enqueue(self, task: Callable, *args: Any, **kwargs: Any) -> None:
        """Schedule `task(*args, **kwargs)` for background execution."""
        ...
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True, kw_only=True, slots=True)
class RedisCache:
    """
    Container for Redis-backed caching configuration.

    Attributes:
        redis: Redis client instance.
        tasks: Background job queue used to refresh stale entries.
        prefix: Optional key prefix (defaults to "cache").
        ttl_fresh: TTL in seconds for the freshness marker (how long a value is considered fresh).
        ttl_stale: TTL in seconds for the cached payload. If None, payload does not expire.
        ttl_lock: TTL in seconds for the stampede protection lock.
    """

    # Redis client used for all cache reads and writes.
    redis: Redis
    # Background job queue; `cached` enqueues refresh jobs here on stale hits.
    tasks: JobsQueue
    prefix: str | None = "cache"
    ttl_fresh: int = 600
    ttl_stale: int | None = None
    ttl_lock: int = 5
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True, kw_only=True, slots=True)
class CachedValue[T]:
    """Wrapper for cached content plus freshness flag."""

    # The (deserialized) cached payload.
    content: T
    # True when the freshness marker existed at read time; False means stale.
    fresh: bool
|
||
|
|
|
||
|
|
|
||
|
|
def cached(cache: RedisCache):
|
||
|
|
"""
|
||
|
|
Decorator factory to cache function results in Redis with a freshness marker.
|
||
|
|
On miss, uses a short-lock (SETNX) to ensure only one process recomputes.
|
||
|
|
When value is stale, returns the stale value and queues a background recompute.
|
||
|
|
|
||
|
|
Accepts either:
|
||
|
|
- a RedisCache instance, or
|
||
|
|
- a callable taking the instance (self) and returning RedisCache.
|
||
|
|
|
||
|
|
If a callable is given, the cache is resolved at call-time using the method's `self`.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator[T](func: Callable[..., T]) -> Callable[..., CachedValue[T]]:
|
||
|
|
# Keys used in Redis:
|
||
|
|
# - "<key>:val" -> JSON-serialized value
|
||
|
|
# - "<key>:fresh" -> existence means fresh (string "1"), TTL = ttl_fresh
|
||
|
|
# - "<key>:lock" -> short-lived lock to prevent stampede
|
||
|
|
|
||
|
|
def _redis_val_key(k: str) -> str:
|
||
|
|
return f"{k}:val"
|
||
|
|
|
||
|
|
def _redis_fresh_key(k: str) -> str:
|
||
|
|
return f"{k}:fresh"
|
||
|
|
|
||
|
|
def _redis_lock_key(k: str) -> str:
|
||
|
|
return f"{k}:lock"
|
||
|
|
|
||
|
|
def _serialize(v: Any) -> str:
|
||
|
|
return json.dumps(v, default=str)
|
||
|
|
|
||
|
|
def _deserialize(s: bytes | str | None) -> Any:
|
||
|
|
if s is None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
return json.loads(s)
|
||
|
|
|
||
|
|
def recompute_value(*args: Any, **kwargs: Any) -> None:
|
||
|
|
"""
|
||
|
|
Recompute the function result and store in Redis.
|
||
|
|
This function is intentionally designed to be enqueued into a background job queue.
|
||
|
|
"""
|
||
|
|
# Compute outside of any Redis lock to avoid holding locks during heavy computation.
|
||
|
|
result = func(*args, **kwargs)
|
||
|
|
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
|
||
|
|
val_key = _redis_val_key(full_key)
|
||
|
|
fresh_key = _redis_fresh_key(full_key)
|
||
|
|
|
||
|
|
# Store payload (with optional ttl_stale)
|
||
|
|
payload = _serialize(result)
|
||
|
|
if cache.ttl_stale is None:
|
||
|
|
# No expiry for payload
|
||
|
|
cache.redis.set(val_key, payload)
|
||
|
|
else:
|
||
|
|
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||
|
|
|
||
|
|
# Create freshness marker with ttl_fresh
|
||
|
|
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||
|
|
|
||
|
|
# Ensure lock removed if present (best-effort)
|
||
|
|
try:
|
||
|
|
cache.redis.delete(_redis_lock_key(full_key))
|
||
|
|
except Exception:
|
||
|
|
# swallow: background job should not crash for Redis delete errors
|
||
|
|
pass
|
||
|
|
|
||
|
|
@functools.wraps(func)
|
||
|
|
def wrapper(*args: Any, **kwargs: Any) -> CachedValue[T]:
|
||
|
|
"""
|
||
|
|
Attempt to return cached value.
|
||
|
|
If missing entirely, try to acquire a short lock and recompute synchronously;
|
||
|
|
otherwise wait briefly for the recompute or return a fallback (None).
|
||
|
|
If present but stale, return the stale value and enqueue background refresh.
|
||
|
|
"""
|
||
|
|
full_key = build_cache_key(func, cache.prefix, *args, **kwargs)
|
||
|
|
val_key = _redis_val_key(full_key)
|
||
|
|
fresh_key = _redis_fresh_key(full_key)
|
||
|
|
lock_key = _redis_lock_key(full_key)
|
||
|
|
|
||
|
|
# Try to get payload and freshness marker
|
||
|
|
raw_payload = cache.redis.get(val_key)
|
||
|
|
is_fresh = cache.redis.exists(fresh_key) == 1
|
||
|
|
|
||
|
|
if raw_payload is None:
|
||
|
|
# Cache miss. Try to acquire lock to recompute synchronously.
|
||
|
|
if cache.redis.setnx(lock_key, "1") == 1:
|
||
|
|
# Ensure lock TTL so it doesn't persist forever
|
||
|
|
cache.redis.expire(lock_key, cache.ttl_lock)
|
||
|
|
|
||
|
|
try:
|
||
|
|
# Recompute synchronously and store
|
||
|
|
value = func(*args, **kwargs)
|
||
|
|
payload = _serialize(value)
|
||
|
|
if cache.ttl_stale is not None:
|
||
|
|
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||
|
|
else:
|
||
|
|
cache.redis.set(val_key, payload)
|
||
|
|
|
||
|
|
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||
|
|
return CachedValue(content=value, fresh=True)
|
||
|
|
finally:
|
||
|
|
# Release lock (best-effort)
|
||
|
|
try:
|
||
|
|
cache.redis.delete(lock_key)
|
||
|
|
except Exception:
|
||
|
|
pass
|
||
|
|
else:
|
||
|
|
# Another process is recomputing. Wait briefly for it to finish.
|
||
|
|
# Do not wait indefinitely; poll a couple times with small backoff.
|
||
|
|
wait_deadline = time.time() + cache.ttl_lock
|
||
|
|
while time.time() < wait_deadline:
|
||
|
|
time.sleep(0.05)
|
||
|
|
raw_payload = cache.redis.get(val_key)
|
||
|
|
if raw_payload is not None:
|
||
|
|
break
|
||
|
|
|
||
|
|
if raw_payload is None:
|
||
|
|
# Still missing after waiting: compute synchronously as fallback to avoid returning None.
|
||
|
|
value = func(*args, **kwargs)
|
||
|
|
payload = _serialize(value)
|
||
|
|
if cache.ttl_stale is None:
|
||
|
|
cache.redis.set(val_key, payload)
|
||
|
|
else:
|
||
|
|
cache.redis.setex(val_key, cache.ttl_stale, payload)
|
||
|
|
cache.redis.setex(fresh_key, cache.ttl_fresh, "1")
|
||
|
|
return CachedValue(content=value, fresh=True)
|
||
|
|
|
||
|
|
# If we reach here, raw_payload is present (either from the start or after waiting)
|
||
|
|
assert not isinstance(raw_payload, Awaitable)
|
||
|
|
deserialized = _deserialize(raw_payload)
|
||
|
|
|
||
|
|
# If fresh marker missing => stale
|
||
|
|
if not is_fresh:
|
||
|
|
# Schedule background refresh; do not block caller.
|
||
|
|
cache.tasks.enqueue(recompute_value, *args, **kwargs)
|
||
|
|
|
||
|
|
return CachedValue(content=deserialized, fresh=is_fresh)
|
||
|
|
|
||
|
|
return wrapper
|
||
|
|
|
||
|
|
return decorator
|