diff --git a/api/advent22_api/core/dav/dec.py b/api/advent22_api/core/dav/dec.py new file mode 100644 index 0000000..116122c --- /dev/null +++ b/api/advent22_api/core/dav/dec.py @@ -0,0 +1,115 @@ +import inspect +from base64 import urlsafe_b64encode +from collections.abc import Callable +from dataclasses import dataclass +from functools import wraps +from itertools import chain +from typing import Iterable, ParamSpec, TypeVar + +from redis import Redis + + +@dataclass(frozen=True, kw_only=True, slots=True) +class Config: + redis: Redis + prefix: str | None = "cache" + ttl_fresh: int = 600 + ttl_stale: int | None = None + + +def qualified_name(callable_obj) -> str: + # callable classes/instances + if hasattr(callable_obj, "__call__") and not inspect.isroutine(callable_obj): + # callable instance: use its class + cls = callable_obj.__class__ + module = getattr(cls, "__module__", None) + qual = getattr(cls, "__qualname__", cls.__name__) + return f"{module}.{qual}" if module else qual + + # functions, methods, builtins + # unwrap descriptors like staticmethod/classmethod + if isinstance(callable_obj, staticmethod | classmethod): + callable_obj = callable_obj.__func__ + + # bound method + if inspect.ismethod(callable_obj): + func = callable_obj.__func__ + owner = getattr(func, "__qualname__", func.__name__).rsplit(".", 1)[0] + module = getattr(func, "__module__", None) + qual = f"{owner}.{func.__name__}" + return f"{module}.{qual}" if module else qual + + # regular function or builtin + if inspect.isfunction(callable_obj) or inspect.isbuiltin(callable_obj): + module = getattr(callable_obj, "__module__", None) + qual = getattr(callable_obj, "__qualname__", callable_obj.__name__) + return f"{module}.{qual}" if module else qual + + # fallback for other callables (functors, functools.partial, etc.) + try: + module = getattr(callable_obj, "__module__", None) + qual = getattr(callable_obj, "__qualname__", None) or getattr( + callable_obj, "__name__", type(callable_obj).__name__ + ) + return f"{module}.{qual}" if module else qual + + except Exception: + return urlsafe_b64encode(repr(callable_obj).encode("utf-8")).decode("utf-8") + + +P = ParamSpec("P") +R = TypeVar("R") + + +def args_slice(func: Callable[P, R], *args: Iterable) -> tuple: + if hasattr(func, "__call__") and not inspect.isroutine(func): + return args[1:] + + if isinstance(func, staticmethod | classmethod): + func = func.__func__ + + if inspect.ismethod(func): + return args[1:] + + return tuple(*args) + + +def redis_cached(cfg: Config) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ """ + + def cache_key(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> str: + """Return a cache key for use with cached methods.""" + + kwargs_by_key = sorted(kwargs.items(), key=lambda kv: kv[0]) + + parts = chain( + (qualified_name(func),), + # positional args + (repr(arg) for arg in args_slice(func, args)), + # keyword args + (f"{k}={v!r}" for k, v in kwargs_by_key), + ) + + if cfg.prefix is not None: + parts = chain( + (cfg.prefix,), + parts, + ) + + return ":".join(parts) + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + """ """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + key = cache_key(func, *args, **kwargs) + + # pre-hook + result = func(*args, **kwargs) + # post-hook + return result + + return wrapper + + return decorator