melobot.di 源代码

from __future__ import annotations

from abc import abstractmethod
from asyncio import Lock
from dataclasses import dataclass
from functools import partial, wraps
from inspect import Parameter, isawaitable, signature
from types import BuiltinFunctionType, FunctionType, LambdaType, MethodType

from typing_extensions import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Generic,
    Sequence,
    cast,
    get_args,
    get_origin,
    overload,
)

from .ctx import BotCtx, EventOrigin, FlowCtx, ParseArgsCtx, SessionCtx, get_logger_type
from .exceptions import DependInitError, DependRuntimeError
from .typ.base import AsyncCallable, P, SyncOrAsyncCallable, T, U, is_subhint, is_type
from .typ.cls import BetterABC
from .utils.base import to_async
from .utils.common import get_obj_name

if TYPE_CHECKING:
    from .adapter.base import Adapter


class DependNotMatched(BaseException):
    def __init__(
        self, msg: str, func_name: str, arg_name: str, real_type: type | None, hint: Any
    ) -> None:
        super().__init__(msg)
        self.func_name = func_name
        self.arg_name = arg_name
        self.real_type = real_type
        self.hint = hint


[文档] class Depends(Generic[T, U]): _EMPTY = object() @overload def __init__( self, dep: SyncOrAsyncCallable[[], T] | Depends[T], sub_getter: None = None, cache: bool = False, recursive: bool = True, ) -> None: ... @overload def __init__( self, dep: SyncOrAsyncCallable[[], U] | Depends[U], sub_getter: SyncOrAsyncCallable[[U], T], cache: bool = False, recursive: bool = True, ) -> None: ...
[文档] def __init__( self, dep: SyncOrAsyncCallable[[], Any] | Depends[Any], sub_getter: SyncOrAsyncCallable[[Any], Any] | None = None, cache: bool = False, recursive: bool = True, ) -> None: """初始化一个依赖项 :param dep: 依赖来源(可调用对象,异步可调用对象,或依赖项) :param sub_getter: 子获取器(可调用对象,异步可调用对象或空),在获得依赖之后,于其上继续获取 :param cache: 是否启用缓存 :param recursive: 是否启用递归满足(默认启用,如果 `dep` 和 `sub_getter` 为可调用对象,会自动被 {func}`.inject_deps` 装饰;关闭可节约性能) """ super().__init__() self.ref: Depends[T] | None self.getter: AsyncCallable[[], T] | None if isinstance(dep, Depends): self.ref = dep self.getter = None else: self.ref = None if recursive: self.getter = inject_deps(dep) # type: ignore[arg-type] else: self.getter = to_async(dep) # type: ignore[arg-type] if sub_getter is None: self.sub_getter = None elif recursive: self.sub_getter = inject_deps(sub_getter) else: self.sub_getter = to_async(sub_getter) self._lock = Lock() if cache else None self._cached: Any = self._EMPTY
def __repr__(self) -> str: getter_str = f"getter={self.getter}" if self.getter is not None else "" ref_str = f"ref={self.ref}" if self.ref is not None else "" return f"{self.__class__.__name__}({ref_str if ref_str != '' else getter_str})" async def _get(self, dep_scope: dict[Depends, Any]) -> T: if self.getter is not None: val = await self.getter() else: ref = cast(Depends[T], self.ref) val = dep_scope.get(ref, self._EMPTY) if val is self._EMPTY: val = await ref.fulfill(dep_scope) if self.sub_getter is not None: val = await self.sub_getter(val) return val async def fulfill(self, dep_scope: dict[Depends, Any]) -> T: if self._lock is None: val = await self._get(dep_scope) elif self._cached is not self._EMPTY: val = self._cached else: async with self._lock: if self._cached is self._EMPTY: self._cached = await self._get(dep_scope) val = self._cached dep_scope[self] = val return val
[文档] class CbDepends(Depends, BetterABC, Generic[T]): """回调型依赖 依赖项,但在依赖满足后,执行内部的回调 """
[文档] def __init__( self, dep: SyncOrAsyncCallable[[], Any], cache: bool = False, recursive: bool = False, ) -> None: super().__init__(dep, cache=cache, recursive=recursive)
[文档] @abstractmethod async def deps_callback(self, val: Any) -> T: """所有子类必须实现该抽象方法 :param val: 依赖项被满足后的值 :return: 处理后的值,作为依赖项最终的值 """ return cast(T, val)
async def fulfill(self, dep_scope: dict[Depends, Any]) -> T: val = await super().fulfill(dep_scope) new_val = await self.deps_callback(val) return new_val
class AutoDepends(CbDepends): def __init__(self, func: Callable, name: str, hint: Any) -> None: self.hint = hint self.func = func self.func_name = get_obj_name(func, otype="callable") self.arg_name = name self._match_event = False if get_origin(hint) is Annotated: args = get_args(hint) if not len(args): raise DependRuntimeError("可依赖注入的函数若使用 Annotated 注解,必须附加元数据") self.metadatas = args else: self.metadatas = () self.orig_getter: SyncOrAsyncCallable[[], Any] | None = None if is_subhint(hint, FlowCtx().get_event_type()): self.orig_getter = FlowCtx().get_event elif is_subhint(hint, BotCtx().get_type()): self.orig_getter = BotCtx().get elif is_subhint(hint, _get_adapter_type()): self.orig_getter = cast(Callable[[], Any], partial(_adapter_get, self, hint)) elif is_subhint(hint, get_logger_type()): self.orig_getter = BotCtx().get_logger elif is_subhint(hint, FlowCtx().get_store_type()): self.orig_getter = FlowCtx().get_store elif is_subhint(hint, SessionCtx().get_session_type()): self.orig_getter = SessionCtx().get elif is_subhint(hint, SessionCtx().get_store_type()): self.orig_getter = SessionCtx().get_store elif is_subhint(hint, SessionCtx().get_rule_type()): self.orig_getter = SessionCtx().get_rule elif is_subhint(hint, ParseArgsCtx().get_args_type()): self.orig_getter = ParseArgsCtx().get elif is_subhint(hint, FlowCtx().get_records_type()): self.orig_getter = FlowCtx().get_records for data in self.metadatas: if isinstance(data, MatchEvent): self._match_event = True if self.orig_getter is None: raise DependInitError( f"函数 {self.func_name} 的参数 {name} 提供的类型注解" f" {hint} 无法用于注入任何依赖,请检查是否有误" ) for data in self.metadatas: if isinstance(data, Reflect): self.orig_getter = cast(Callable[[], Any], partial(Reflection, self.orig_getter)) break super().__init__(self.orig_getter, cache=False, recursive=False) def _unmatch_exc(self, real_type: Any) -> DependNotMatched: if real_type is Depends._EMPTY: return DependNotMatched( f"函数 {self.func_name} 的参数 {self.arg_name} 对应的依赖项," f"在当前上下文中不存在", self.func_name, self.arg_name, None, self.hint, ) else: return DependNotMatched( f"函数 {self.func_name} 的参数 {self.arg_name} 与注解 {self.hint} 不匹配", self.func_name, self.arg_name, real_type, self.hint, ) async def deps_callback(self, val: Any) -> Any: ret = val if isinstance(val, Reflection): val = val.__origin__ for data in self.metadatas: if isinstance(data, Exclude): if any(isinstance(val, t) for t in data.types): raise self._unmatch_exc(type(val)) if not is_type(val, self.hint): raise self._unmatch_exc(type(val)) return ret def _get_adapter_type() -> type["Adapter"]: from .adapter.base import Adapter return Adapter def _adapter_get(deps: AutoDepends, hint: Any) -> "Adapter": if not deps._match_event: if get_origin(hint) is Annotated: args = get_args(hint) if not len(args): raise DependRuntimeError("可依赖注入的函数若使用 Annotated 注解,必须附加元数据") adapter_type = args[0] else: adapter_type = hint adapter = BotCtx().get().get_adapter(adapter_type) if adapter is None: raise deps._unmatch_exc(Depends._EMPTY) from None return cast("Adapter", adapter) else: flow_ctx = FlowCtx() try: event = flow_ctx.get_event() return EventOrigin.get_origin(event).adapter except flow_ctx.lookup_exc_cls: raise deps._unmatch_exc(Depends._EMPTY) from None
[文档] @dataclass class Exclude: """数据类。`types` 指定的类别会在依赖注入时被排除 .. code:: python # 假设有 B 继承于 A, C 继承于 A, D 继承于 A # 表示不包括 B 和 C 类别的 A 的所有子类型: NewTypeHint = Annotated[A, Exclude(types=[B, C])] # 当然,依然会兼容 A 类型 """ types: Sequence[type]
[文档] @dataclass class Reflect: """数据类。指定不直接获取当前依赖项,而是获取对应的一个反射代理 这适用于希望依赖会随着上下文改变,而动态变化的情况。例如动态引用会话流程中的事件对象 .. code:: python # 注入一个依赖时进一步包装为反射依赖 ReflectedEvent = Annotated[Event, Reflect()] event_proxy: RefelectedEvent # 就像使用 event 一样使用 event_proxy event_proxy.attr_xxx event_proxy.method_xxx() # 不过 event_proxy 不是完美的代理 # 因此 isinstance 类似的操作,使用 __origin__ 获取原始对象 isinstance(event_proxy.__origin__, SomeEventType) # 或者是作为运行逻辑未知的函数的参数 dont_know_what_this_do(event_proxy.__origin__) """
[文档] @dataclass class MatchEvent: """数据类。指定从当前事件的上下文中获取依赖 默认情况下,获取 Adapter 依赖都会直接尝试遍历所有可能的对象。 即尽最大可能获取指定类型的对象。但有时需要实现这样的需求: .. code:: python # 假设 bot 已经加载了两个适配器:ObAdapter 和 XxAdapter from melobot.handle import on_event # 期待事件来自 ObAdapter 时,调用这个 @on_event() async def on_onebot_event(adapter: ObAdapter) -> None: ... # 期待事件来自 XxAdapter 时,调用这个 @on_event() async def on_xx_event(adapter: XxAdapter) -> None: ... # 但默认的逻辑是:bot 只要加载了对应的适配器,依赖就可以满足 # 所以实际上他们都会被调用,没有任何区分效果 # 使用 MatchEvent 来改变依赖获取的逻辑: # 事件的来源适配器必须和 MatchEvent 中指定的类型匹配,依赖才能满足 @on_event() async def on_onebot_event(adapter: Annotated[ObAdapter, MatchEvent()]) -> None: ... @on_event() async def on_xx_event(adapter: Annotated[XxAdapter, MatchEvent()]) -> None: ... """
class Reflection: def __init__(self, getter: Callable[[], Any]) -> None: super().__setattr__("__obj_getter__", getter) @property def __origin__(self) -> Any: return self.__obj_getter__() def __getattr__(self, name: str) -> Any: getter = self.__obj_getter__ if name.startswith("_"): raise AttributeError(f"在反射对象上,不允许访问名称以 _ 开头的属性:{name}") return getattr(getter(), name) def __setattr__(self, name: str, value: Any) -> Any: getter = self.__obj_getter__ if name == "__obj_getter__": return getter if name.startswith("_"): raise AttributeError(f"在反射对象上,不允许修改名称以 _ 开头的属性:{name}") return setattr(getter(), name, value) _DI_DEFAULTS = "MELOBOT_DI_TUPLE" _DI_KW_DEFAULTS = "MELOBOT_DI_DICT" _DI_INJECTED = "MELOBOT_DI_INJECTED" def _init_auto_deps(func: Callable[P, T], allow_manual_arg: bool) -> None: # 无需再考虑对于多层装饰的兼容,signature 方法会获取原始签名 # 若装饰过程在逻辑上改变了签名,这种情况属于错误用例,无需兼容 sign = signature(func) empty = Parameter.empty args: list[Any] = [] kwargs = {} for name, param in sign.parameters.items(): # 可变位置参数或可变关键字参数时,无需任何操作 if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD): continue # 有默认值的情况,保存默认值并跳过 if param.default is not empty: if param.kind is Parameter.KEYWORD_ONLY: kwargs[name] = param.default else: args.append(param.default) continue # 没有默认值,没有注解的参数,跳过 # 无法识别依赖类型,只能通过手动传参提供 if param.annotation is empty: continue # 剩余情况需要注入自动依赖 try: dep = None if get_origin(param.annotation) is Annotated: anno_args = get_args(param.annotation) if not len(anno_args): raise DependRuntimeError( "可依赖注入的函数若使用 Annotated 注解,必须附加元数据" ) for v in anno_args: if isinstance(v, Depends): dep = v break if dep is None: dep = AutoDepends(func, name, param.annotation) except DependInitError: # 如果允许手动传参,无法识别的依赖显然是允许的 if allow_manual_arg: continue raise if param.kind is Parameter.KEYWORD_ONLY: kwargs[name] = dep else: # 按照参数顺序遍历,添加到默认值列表是保序的 args.append(dep) func.__dict__[_DI_DEFAULTS] = tuple(args) func.__dict__[_DI_KW_DEFAULTS] = kwargs func.__dict__[_DI_INJECTED] = True
[文档] def inject_deps( injectee: SyncOrAsyncCallable[..., T], manual_arg: bool = False, avoid_repeat: bool = False ) -> AsyncCallable[..., T]: """依赖注入标记装饰器,标记当前对象需要被依赖注入 可以标记的对象类别有: 同步函数,异步函数,匿名函数,同步生成器函数,异步生成器函数,实例方法、类方法、静态方法 :param injectee: 需要被注入的对象 :param manual_arg: 当前对象标记需要依赖注入后,是否还可以给某些参数手动传参 :param avoid_repeat: 是否避免在多层装饰时重复注入。检查内层装饰链,若内层装饰已有注入则放弃本次注入。 但需要所有内层装饰使用 :func:`functools.wraps` 进行包装,否则无法检测 :return: 异步可调用对象,但保留原始参数和返回值签名 """ @wraps(injectee) async def inject_deps_wrapped(*args: Any, **kwargs: Any) -> T: defaults: tuple[Any] = injectee.__dict__[_DI_DEFAULTS] kw_defaults: dict[str, Any] = injectee.__dict__[_DI_KW_DEFAULTS] # 模拟已有参数替换默认值的情况 _args = [*args, *defaults[len(args) :]] _kwargs = kw_defaults.copy() | kwargs dep_scope: dict[Depends, Any] = {} for idx, _ in enumerate(_args): elem = _args[idx] if isinstance(elem, Depends): _args[idx] = await elem.fulfill(dep_scope) for k in _kwargs: elem = _kwargs[k] if isinstance(elem, Depends): _kwargs[k] = await elem.fulfill(dep_scope) try: ret = injectee(*_args, **_kwargs) # type: ignore[arg-type] except TypeError as e: fname = get_obj_name(injectee, otype="callable") raise DependRuntimeError( f"依赖注入下的函数 {fname} 调用时发生错误:{e}。" "可能是参数存在问题:传参个数不匹配,或提供了错误的类型注解;" "或是函数内部逻辑错误。" ) from None if isawaitable(ret): return await ret return ret if avoid_repeat: f = injectee while True: if hasattr(f, _DI_INJECTED): return to_async(injectee) elif hasattr(f, "__wrapped__"): f = f.__wrapped__ else: break if isinstance(injectee, (FunctionType, MethodType)): _init_auto_deps(injectee, manual_arg) return inject_deps_wrapped if isinstance(injectee, LambdaType): injectee.__dict__[_DI_DEFAULTS] = injectee.__defaults__ or () injectee.__dict__[_DI_KW_DEFAULTS] = injectee.__kwdefaults__ or {} injectee.__dict__[_DI_INJECTED] = True return inject_deps_wrapped if isinstance(injectee, partial): injectee.__dict__[_DI_DEFAULTS] = injectee.args or () injectee.__dict__[_DI_KW_DEFAULTS] = injectee.keywords or {} injectee.__dict__[_DI_INJECTED] = True return inject_deps_wrapped if isinstance(injectee, BuiltinFunctionType): raise DependInitError(f"内建函数 {injectee} 不支持依赖注入") raise DependInitError( f"{injectee} 对象不属于以下类别中的任何一种:" "{同步函数,异步函数,匿名函数,partial 对象,同步生成器函数,异步生成器函数," "实例方法、类方法、静态方法},因此不能被注入依赖" )