Compare commits

...

2 commits

2 changed files with 126 additions and 5 deletions

View file

@ -0,0 +1,115 @@
import inspect
from base64 import urlsafe_b64encode
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from itertools import chain
from typing import Iterable, ParamSpec, TypeVar
from redis import Redis
@dataclass(frozen=True, kw_only=True, slots=True)
class Config:
redis: Redis
prefix: str | None = "cache"
ttl_fresh: int = 600
ttl_stale: int | None = None
def qualified_name(callable_obj) -> str:
# callable classes/instances
if hasattr(callable_obj, "__call__") and not inspect.isroutine(callable_obj):
# callable instance: use its class
cls = callable_obj.__class__
module = getattr(cls, "__module__", None)
qual = getattr(cls, "__qualname__", cls.__name__)
return f"{module}.{qual}" if module else qual
# functions, methods, builtins
# unwrap descriptors like staticmethod/classmethod
if isinstance(callable_obj, staticmethod | classmethod):
callable_obj = callable_obj.__func__
# bound method
if inspect.ismethod(callable_obj):
func = callable_obj.__func__
owner = getattr(func, "__qualname__", func.__name__).rsplit(".", 1)[0]
module = getattr(func, "__module__", None)
qual = f"{owner}.{func.__name__}"
return f"{module}.{qual}" if module else qual
# regular function or builtin
if inspect.isfunction(callable_obj) or inspect.isbuiltin(callable_obj):
module = getattr(callable_obj, "__module__", None)
qual = getattr(callable_obj, "__qualname__", callable_obj.__name__)
return f"{module}.{qual}" if module else qual
# fallback for other callables (functors, functools.partial, etc.)
try:
module = getattr(callable_obj, "__module__", None)
qual = getattr(callable_obj, "__qualname__", None) or getattr(
callable_obj, "__name__", type(callable_obj).__name__
)
return f"{module}.{qual}" if module else qual
except Exception:
return urlsafe_b64encode(repr(callable_obj).encode("utf-8")).decode("utf-8")
P = ParamSpec("P")
R = TypeVar("R")
def args_slice(func: Callable[P, R], *args: Iterable) -> tuple:
if hasattr(func, "__call__") and not inspect.isroutine(func):
return args[1:]
if isinstance(func, staticmethod | classmethod):
func = func.__func__
if inspect.ismethod(func):
return args[1:]
return tuple(*args)
def redis_cached(cfg: Config) -> Callable[[Callable[P, R]], Callable[P, R]]:
""" """
def cache_key(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> str:
"""Return a cache key for use with cached methods."""
kwargs_by_key = sorted(kwargs.items(), key=lambda kv: kv[0])
parts = chain(
(qualified_name(func),),
# positional args
(repr(arg) for arg in args_slice(func, args)),
# keyword args
(f"{k}={v!r}" for k, v in kwargs_by_key),
)
if cfg.prefix is not None:
parts = chain(
(cfg.prefix,),
parts,
)
return ":".join(parts)
def decorator(func: Callable[P, R]) -> Callable[P, R]:
""" """
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
key = cache_key(func, *args, **kwargs)
# pre-hook
result = func(*args, **kwargs)
# post-hook
return result
return wrapper
return decorator

View file

@ -1,6 +1,7 @@
from datetime import date from datetime import date
from enum import Enum
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel from pydantic import BaseModel
from advent22_api.core.helpers import EventDates from advent22_api.core.helpers import EventDates
@ -173,16 +174,21 @@ async def put_doors(
await cal_cfg.change(cfg) await cal_cfg.change(cfg)
class CredentialsName(str, Enum):
DAV = "dav"
UI = "ui"
@router.get("/credentials/{name}") @router.get("/credentials/{name}")
async def get_credentials( async def get_credentials(
name: str, name: CredentialsName,
_: None = Depends(require_admin), _: None = Depends(require_admin),
cfg: Config = Depends(get_config), cfg: Config = Depends(get_config),
) -> Credentials: ) -> Credentials:
if name == "dav": if name == CredentialsName.DAV:
return SETTINGS.webdav.auth return SETTINGS.webdav.auth
elif name == "ui": elif name == CredentialsName.UI:
return cfg.admin return cfg.admin
else: else:
return Credentials() raise HTTPException(status.HTTP_400_BAD_REQUEST)