melobot.di 源代码

from __future__ import annotations

from abc import abstractmethod
from asyncio import Lock
from collections import deque
from dataclasses import dataclass
from functools import partial, wraps
from inspect import Parameter, isawaitable, signature, unwrap
from sys import version_info
from types import BuiltinFunctionType, FunctionType, LambdaType

from typing_extensions import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Generic,
    Sequence,
    cast,
    get_args,
    get_origin,
)

from .ctx import BotCtx, EventOrigin, FlowCtx, LoggerCtx, ParseArgsCtx, SessionCtx
from .exceptions import DependBindError, DependInitError
from .typ._enum import VoidType
from .typ.base import AsyncCallable, P, SyncOrAsyncCallable, T, is_subhint, is_type
from .typ.cls import BetterABC
from .utils.base import to_async
from .utils.common import get_obj_name

if TYPE_CHECKING:
    from .adapter.base import Adapter


class DependNotMatched(BaseException):
    """Raised when an injected value does not match the annotation of a parameter.

    The class derives from ``BaseException``, so a plain ``except Exception``
    handler does not intercept it.

    :param msg: human-readable description of the mismatch
    :param func_name: name of the function whose parameter is being injected
    :param arg_name: name of the parameter that failed to match
    :param real_type: type of the value that was actually obtained
    :param hint: annotation the value was checked against
    """

    def __init__(self, msg: str, func_name: str, arg_name: str, real_type: type, hint: Any) -> None:
        super().__init__(msg)
        # Keep the full context of the mismatch on the exception for handlers.
        self.func_name, self.arg_name = func_name, arg_name
        self.real_type, self.hint = real_type, hint


class Depends(Generic[T]):
    """A dependency: a value obtained from a getter or from another dependency."""

    def __init__(
        self,
        dep: SyncOrAsyncCallable[[], T] | Depends[T],
        sub_getter: SyncOrAsyncCallable[[T], T] | None = None,
        cache: bool = False,
        recursive: bool = True,
    ) -> None:
        """Initialise a dependency.

        :param dep: dependency source (a callable, an async callable, or another dependency)
        :param sub_getter:
            sub-getter (a callable, an async callable, or None); once the dependency
            value is obtained, it is passed through this getter
        :param cache: whether caching is enabled
        :param recursive:
            whether recursive fulfilment is enabled (on by default: dependencies found
            inside the dependency source are fulfilled recursively; turning it off
            saves some work)
        """
        super().__init__()
        self.ref: Depends[T] | None
        self.getter: AsyncCallable[[], T] | None

        # Both the source and the sub-getter are wrapped the same way:
        # with injection support when recursive, as a plain async callable otherwise.
        wrap: Callable[[Any], Any] = inject_deps if recursive else to_async

        if isinstance(dep, Depends):
            self.ref = dep
            self.getter = None
        else:
            self.ref = None
            self.getter = wrap(dep)

        self.sub_getter = None if sub_getter is None else wrap(sub_getter)

        # A lock only exists when caching is requested; `fulfill` uses its
        # presence to decide whether the cache is in play.
        self._lock = Lock() if cache else None
        self._cached: Any = VoidType.VOID

    def __repr__(self) -> str:
        if self.ref is not None:
            inner = f"ref={self.ref}"
        elif self.getter is not None:
            inner = f"getter={self.getter}"
        else:
            inner = ""
        return f"{self.__class__.__name__}({inner})"

    async def _get(self, dep_scope: dict[Depends, Any]) -> T:
        """Compute the value, bypassing the cache.

        :param dep_scope: values already fulfilled during the current injection
        :return: the value, after the sub-getter (if any) has been applied
        """
        result: T | VoidType
        if self.getter is None:
            # No getter means this dependency refers to another one; reuse the
            # value recorded in the scope, or fulfil the referenced dependency.
            upstream = cast(Depends[T], self.ref)
            result = dep_scope.get(upstream, VoidType.VOID)
            if result is VoidType.VOID:
                result = await upstream.fulfill(dep_scope)
        else:
            result = await self.getter()

        if self.sub_getter is None:
            return result
        return await self.sub_getter(result)

    async def fulfill(self, dep_scope: dict[Depends, Any]) -> T:
        """Obtain the value of this dependency and record it in ``dep_scope``.

        :param dep_scope: values already fulfilled during the current injection
        :return: the fulfilled value
        """
        if self._lock is None:
            value = await self._get(dep_scope)
        else:
            if self._cached is VoidType.VOID:
                async with self._lock:
                    # Re-check under the lock: another task may have filled the
                    # cache while this one was waiting.
                    if self._cached is VoidType.VOID:
                        self._cached = await self._get(dep_scope)
            value = self._cached

        dep_scope[self] = value
        return value
class AutoDepends(Depends):
    """Dependency built automatically from the annotation of a function parameter."""

    def __init__(self, func: Callable, name: str, hint: Any) -> None:
        """Pick a getter for ``hint`` and initialise the dependency with it.

        :param func: function whose parameter is being injected
        :param name: name of the parameter
        :param hint: annotation of the parameter
        :raises DependInitError: when no getter can be derived from ``hint``
        """
        self.hint = hint
        self.func = func
        self.func_name = get_obj_name(func, otype="callable")
        self.arg_name = name

        self.metadatas = ()
        if get_origin(hint) is Annotated:
            hint_args = get_args(hint)
            if len(hint_args) == 0:
                raise DependInitError("可依赖注入的函数若使用 Annotated 注解,必须附加元数据")
            self.metadatas = hint_args

        self.orig_getter: SyncOrAsyncCallable[[], Any] | None = None
        getter: SyncOrAsyncCallable[[], Any] | None = None

        # First matching context type wins; the order of the checks is significant.
        if is_subhint(hint, FlowCtx().get_event_type()):
            getter = FlowCtx().get_event
        elif is_subhint(hint, BotCtx().get_type()):
            getter = BotCtx().get
        elif is_subhint(hint, _get_adapter_type()):
            getter = cast(Callable[[], Any], partial(_adapter_get, self, hint))
        elif is_subhint(hint, LoggerCtx().get_type()):
            getter = LoggerCtx().get
        elif is_subhint(hint, FlowCtx().get_store_type()):
            getter = FlowCtx().get_store
        elif is_subhint(hint, SessionCtx().get_session_type()):
            getter = SessionCtx().get
        elif is_subhint(hint, SessionCtx().get_store_type()):
            getter = SessionCtx().get_store
        elif is_subhint(hint, SessionCtx().get_rule_type()):
            getter = SessionCtx().get_rule
        elif is_subhint(hint, ParseArgsCtx().get_args_type()):
            getter = ParseArgsCtx().get

        # A CustomLogger metadata entry overrides whatever getter was chosen above.
        custom_logger = next(
            (meta for meta in self.metadatas if isinstance(meta, CustomLogger)), None
        )
        if custom_logger is not None:
            getter = cast(Callable[[], Any], partial(_custom_logger_get, hint, custom_logger))

        if getter is None:
            raise DependInitError(
                f"函数 {self.func_name} 的参数 {name} 提供的类型注解"
                f" {hint} 无法用于注入任何依赖,请检查是否有误"
            )

        # A Reflect metadata entry makes the getter produce a Reflection around
        # the original getter instead of the value itself.
        if any(isinstance(meta, Reflect) for meta in self.metadatas):
            getter = cast(Callable[[], Any], partial(Reflection, getter))

        self.orig_getter = getter
        super().__init__(self.orig_getter, sub_getter=None, cache=False, recursive=False)

    def _unmatch_exc(self, real_type: Any) -> DependNotMatched:
        """Build the exception describing a mismatch for this parameter."""
        message = f"函数 {self.func_name} 的参数 {self.arg_name} " f"与注解 {self.hint} 不匹配"
        return DependNotMatched(message, self.func_name, self.arg_name, real_type, self.hint)

    def _match_check(self, val: Any) -> None:
        """Validate ``val`` against the metadata and the annotation.

        :raises DependNotMatched: when ``val`` is an excluded type or fails the hint check
        """
        for meta in self.metadatas:
            if isinstance(meta, Exclude) and any(isinstance(val, t) for t in meta.types):
                raise self._unmatch_exc(type(val))

        # With a CustomLogger present, the hint check is skipped.
        if any(isinstance(meta, CustomLogger) for meta in self.metadatas):
            return

        if not is_type(val, self.hint):
            raise self._unmatch_exc(type(val))

    async def fulfill(self, dep_scope: dict[Depends, Any]) -> Any:
        """Fulfil the dependency and check the obtained value.

        For a ``Reflection`` the check runs on the object behind it, while the
        reflection itself is what gets returned.
        """
        val = await super().fulfill(dep_scope)
        if not isinstance(val, Reflection):
            self._match_check(val)
            return val

        origin = val.__origin__
        if isawaitable(origin):
            raise AttributeError(f"异步依赖项不能通过 {Reflect.__name__} 创建反射依赖")
        self._match_check(origin)
        return val


def _get_adapter_type() -> type["Adapter"]:
    """Return the ``Adapter`` class, imported at call time rather than module import time."""
    from .adapter.base import Adapter

    return Adapter


def _adapter_get(deps: AutoDepends, hint: Any) -> "Adapter":
    """Return the adapter of the current event, or look one up on the bot by ``hint``.

    :raises DependNotMatched: when the fallback lookup on the bot yields None
    """
    ctx = FlowCtx()
    try:
        return EventOrigin.get_origin(ctx.get_event()).adapter
    except ctx.lookup_exc_cls:
        found = BotCtx().get().get_adapter(hint)
        if found is not None:
            return found
        raise deps._unmatch_exc(VoidType) from None


def _custom_logger_get(hint: Any, data: CustomLogger) -> Any:
    """Return the context logger when it matches ``hint``, else the result of ``data.getter()``."""
    current = LoggerCtx().get()
    return current if is_type(current, hint) else data.getter()
@dataclass
class Exclude:
    """Data class. The types listed in `types` are excluded during dependency injection.

    .. code:: python

        # Given the inheritance A <- B, A <- C, A <- D,
        # this denotes A without any subtype of B or C; A itself is still accepted
        NewTypeHint = Annotated[A, Exclude(types=[B, C])]
    """

    # Types whose instances must be rejected.
    types: Sequence[type]
@dataclass
class CustomLogger:
    """Data class. `getter` tells how to obtain a logger when one of the requested type is absent.

    .. code:: python

        # If the logger configured on the bot is a MyLogger, injection succeeds;
        # otherwise a logger is obtained through getter
        NewLoggerHint = Annotated[MyLogger, CustomLogger(getter=MyLogger)]
    """

    # Zero-argument callable producing the fallback logger.
    getter: Callable[[], Any]
@dataclass
class Reflect:
    """Data class. Requests a reflection proxy of the dependency instead of the dependency itself.

    This suits dependencies that should change dynamically along with the context,
    for example a live reference to the event object of a session flow.

    .. code:: python

        # Wrap an injected dependency further as a reflected dependency
        event_proxy = Annotated[Event, Reflect()]
        # Use event_proxy just like the event
        event_proxy.attr_xxx
        event_proxy.method_xxx()

        # event_proxy is not a perfect proxy, though,
        # so for isinstance-like operations use __origin__ to get the original object
        isinstance(event_proxy.__origin__, SomeEventType)
        # the same goes for passing it to a function whose logic is unknown
        dont_know_what_this_do(event_proxy.__origin__)
    """
class Reflection:
    """Proxy that forwards public attribute access to the object returned by a getter.

    The getter is called on every access, so the proxy always reflects whatever
    the getter currently returns. Names starting with ``_`` are never forwarded.
    """

    def __init__(self, getter: Callable[[], Any]) -> None:
        # Bypass our own __setattr__, which would forward the write to the target.
        super().__setattr__("__obj_getter__", getter)

    @property
    def __origin__(self) -> Any:
        """The object currently returned by the getter."""
        return self.__obj_getter__()

    def __getattr__(self, name: str) -> Any:
        """Forward reads of public attributes to the target object.

        :raises AttributeError: for names starting with ``_``
        """
        if name == "__obj_getter__":
            # __getattr__ only runs after normal lookup has failed, so reaching
            # this point means the getter slot was never set (for example on an
            # instance created through __new__ without __init__). Reading
            # self.__obj_getter__ here would re-enter __getattr__ without end.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        if name.startswith("_"):
            raise AttributeError(f"在反射对象上,不允许访问名称以 _ 开头的属性:{name}")
        return getattr(self.__obj_getter__(), name)

    def __setattr__(self, name: str, value: Any) -> Any:
        """Forward writes of public attributes to the target object.

        A write to ``__obj_getter__`` is ignored and the current getter is returned.

        :raises AttributeError: for other names starting with ``_``
        """
        getter = self.__obj_getter__
        if name == "__obj_getter__":
            return getter
        if name.startswith("_"):
            raise AttributeError(f"在反射对象上,不允许修改名称以 _ 开头的属性:{name}")
        return setattr(getter(), name, value)


def _init_auto_deps(func: Callable[P, T], allow_manual_arg: bool) -> None:
    """Install ``AutoDepends`` objects as the defaults of annotated parameters of ``func``.

    Existing defaults are kept. Parameters with neither a default nor an
    annotation are skipped. The defaults are written onto the function found by
    unwrapping ``func``.

    :param func: function to prepare for injection
    :param allow_manual_arg:
        when true, a parameter whose annotation cannot be injected is skipped
        instead of raising
    :raises DependInitError:
        when an annotation cannot be injected and ``allow_manual_arg`` is false,
        or when a builtin exposes no signature on python <= 3.10
    """
    try:
        sign = signature(func)
    except ValueError as e:
        tip = "no signature found for builtin"
        # Compare only (major, minor): the full version_info tuple of any 3.10.x
        # release is greater than (3, 10), which would skip this branch.
        if str(e).startswith(tip) and version_info[:2] <= (3, 10):
            raise DependInitError(
                f"内建函数 {func} 在 python <= 3.10 的版本中,无法进行依赖注入"
            ) from None
        raise

    empty = Parameter.empty
    origin_f = unwrap(func, stop=lambda f: hasattr(f, "__signature__"))
    ds = deque(origin_f.__defaults__) if origin_f.__defaults__ is not None else deque()
    kwds = origin_f.__kwdefaults__ if origin_f.__kwdefaults__ is not None else {}
    nargs: list[Any] = []

    for name, param in sign.parameters.items():
        if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
            continue

        if param.default is not empty:
            # Keyword-only defaults already live in kwds; other defaults are
            # consumed from ds and carried over unchanged.
            if name not in kwds:
                ds.popleft()
                nargs.append(param.default)
            continue

        if param.annotation is empty:
            continue

        try:
            dep = AutoDepends(func, name, param.annotation)
        except DependInitError:
            if allow_manual_arg:
                continue
            raise

        if param.kind is Parameter.KEYWORD_ONLY:
            kwds[name] = dep
        else:
            nargs.append(dep)

    origin_f.__defaults__ = tuple(nargs) if nargs else None
    origin_f.__kwdefaults__ = kwds if kwds else None  # type: ignore[assignment]


def _get_bound_args(
    func: Callable, /, *args: Any, **kwargs: Any
) -> tuple[list[Any], dict[str, Any]]:
    """Bind the arguments to the signature of ``func`` and fill in its defaults.

    :return: the positional arguments as a list and the keyword arguments as a dict
    :raises DependBindError: when the arguments do not fit the signature
    """
    sign = signature(func)

    try:
        bind = sign.bind(*args, **kwargs)
    except TypeError as e:
        fname = get_obj_name(func, otype="callable")
        raise DependBindError(
            f"依赖注入匹配失败。匹配函数 {fname} 的参数时发生错误:{e}。"
            "这可能是因为传参个数不匹配,或提供了错误的类型注解"
        ) from None

    bind.apply_defaults()
    return list(bind.args), bind.kwargs
class DependsHook(Depends[T], BetterABC):
    """Dependency hook.

    Wraps a dependency; once the dependency is fulfilled, the hook inside runs.
    """

    def __init__(
        self,
        dep: SyncOrAsyncCallable[[], T],
        cache: bool = False,
        recursive: bool = False,
    ) -> None:
        """Initialise the hook around ``dep``.

        :param dep: dependency source (a callable or an async callable)
        :param cache: whether caching is enabled
        :param recursive: whether recursive fulfilment is enabled (off by default here)
        """
        super().__init__(dep, sub_getter=None, cache=cache, recursive=recursive)

    @abstractmethod
    async def deps_callback(self, val: T) -> None:
        """Every dependency hook subclass must implement this abstract method.

        :param val: value of the dependency after it has been fulfilled
        """
        raise NotImplementedError

    async def fulfill(self, dep_scope: dict[Depends, Any]) -> T:
        """Fulfil the dependency, run ``deps_callback`` on the value, then return it."""
        fulfilled = await super().fulfill(dep_scope)
        await self.deps_callback(fulfilled)
        return fulfilled
def inject_deps(
    injectee: SyncOrAsyncCallable[..., T], manual_arg: bool = False
) -> AsyncCallable[..., T]:
    """Dependency-injection marker decorator: marks the object as needing injection.

    Kinds of object that can be marked:
    sync functions, async functions, lambdas, sync generator functions,
    async generator functions, instance methods, class methods, static methods.

    :param injectee: the object to inject into
    :param manual_arg:
        whether some parameters may still be passed manually once the object
        has been marked for injection
    :return: an async callable that keeps the original parameter and return signature
    :raises DependInitError: when ``injectee`` is not a function or builtin function
    """

    @wraps(injectee)
    async def inject_deps_wrapped(*args: Any, **kwargs: Any) -> T:
        _args, _kwargs = _get_bound_args(injectee, *args, **kwargs)
        # One scope per call, so a dependency referenced several times within
        # the same call can reuse the value already fulfilled.
        dep_scope: dict[Depends, Any] = {}

        for idx, elem in enumerate(_args):
            if isinstance(elem, Depends):
                _args[idx] = await elem.fulfill(dep_scope)

        # Only values of existing keys are replaced, so iterating items() is safe.
        for key, elem in _kwargs.items():
            if isinstance(elem, Depends):
                _kwargs[key] = await elem.fulfill(dep_scope)

        ret = injectee(*_args, **_kwargs)  # type: ignore[arg-type]
        if isawaitable(ret):
            return await ret
        return ret

    # types.LambdaType is the same object as types.FunctionType, so lambdas are
    # prepared exactly like any other function.
    if isinstance(injectee, (FunctionType, LambdaType, BuiltinFunctionType)):
        _init_auto_deps(injectee, manual_arg)
        return inject_deps_wrapped

    raise DependInitError(
        f"{injectee} 对象不属于以下类别中的任何一种:"
        "{同步函数,异步函数,匿名函数,同步生成器函数,异步生成器函数,"
        "实例方法、类方法、静态方法},因此不能被注入依赖"
    )