from __future__ import annotations
import asyncio
from abc import abstractmethod
from asyncio import create_task
from contextlib import AsyncExitStack, _GeneratorContextManager, asynccontextmanager
from enum import Enum
from os import PathLike
from typing_extensions import (
TYPE_CHECKING,
AsyncGenerator,
Callable,
Generic,
Iterable,
LiteralString,
NoReturn,
Self,
Sequence,
TypeVar,
cast,
final,
)
from ..ctx import EventOrigin, FlowCtx, OutSrcFilterCtx
from ..exceptions import AdapterError
from ..io.base import (
AbstractInSource,
AbstractOutSource,
AbstractSource,
EchoPacketT,
InPacketT,
InSourceT,
OutPacketT,
OutSourceT,
)
from ..log.reflect import logger
from ..mixin import HookMixin
from ..typ.base import AsyncCallable, P, SyncOrAsyncCallable
from ..typ.cls import BetterABC
from .content import Content, TextContent
from .model import ActionHandle, ActionHandleGroup, ActionT, EchoT, Event, EventT
if TYPE_CHECKING:
from ..bot.dispatch import Dispatcher
_OUT_SRC_FILTER_CTX = OutSrcFilterCtx()
_FLOW_CTX = FlowCtx()
[文档]
class AbstractEventFactory(BetterABC, Generic[InPacketT, EventT]):
"""抽象事件工厂类"""
[文档]
@abstractmethod
async def create(self, packet: InPacketT) -> EventT:
"""将 :class:`.InPacket` 对象转换为 :class:`.Event` 对象的方法
:param packet: 输入包
:return: 事件
"""
raise NotImplementedError
EventFactoryT = TypeVar("EventFactoryT", bound=AbstractEventFactory)
[文档]
class AbstractOutputFactory(BetterABC, Generic[OutPacketT, ActionT]):
"""抽象输出工厂类"""
[文档]
@abstractmethod
async def create(self, action: ActionT) -> OutPacketT:
"""将 :class:`.Action` 对象转换为 :class:`.OutPacket` 对象的方法
:param packet: 行为
:return: 输出包
"""
raise NotImplementedError
OutputFactoryT = TypeVar("OutputFactoryT", bound=AbstractOutputFactory)
[文档]
class AbstractEchoFactory(BetterABC, Generic[EchoPacketT, EchoT]):
"""抽象回应工厂类"""
[文档]
@abstractmethod
async def create(self, packet: EchoPacketT) -> EchoT | None:
"""将 :class:`.EchoPacket` 对象转换为 :class:`.Echo` 对象的方法
:param packet: 回应包
:return: 回应或空值
"""
raise NotImplementedError
EchoFactoryT = TypeVar("EchoFactoryT", bound=AbstractEchoFactory)
[文档]
class AdapterLifeSpan(Enum):
"""适配器生命周期阶段的枚举"""
#: 适配器将一个事件投入分发器之前(提供首参数类型为 Event 的依赖注入,你可以自行细化类型)
BEFORE_EVENT_HANDLE = "beh"
#: 适配器将一个行为投入执行之前(提供首参数类型为 Action 的依赖注入,你可以自行细化类型)
BEFORE_ACTION_EXEC = "bae"
#: 适配器启动完成之后(此时所有同协议的源已经开始工作)
STARTED = "sta"
#: 适配器停止即将发生前(出现异常可能不会调用此类型)
CLOSE = "clo"
#: 适配器停止完成后(包含任何情况导致的停止)
STOPPED = "sto"
[文档]
class Adapter(
HookMixin[AdapterLifeSpan],
Generic[
EventFactoryT,
OutputFactoryT,
EchoFactoryT,
ActionT,
InSourceT,
OutSourceT,
],
BetterABC,
):
"""适配器基类
:ivar LiteralString protocol: 适配器所使用的协议
"""
def __init__(
self,
protocol: LiteralString,
event_factory: EventFactoryT,
output_factory: OutputFactoryT,
echo_factory: EchoFactoryT,
) -> None:
super().__init__(hook_type=AdapterLifeSpan, hook_tag=protocol)
self.protocol = protocol
self.in_srcs: set[InSourceT] = set()
self.out_srcs: set[OutSourceT] = set()
self.dispatcher: "Dispatcher"
self._event_factory = event_factory
self._output_factory = output_factory
self._echo_factory = echo_factory
self._inited = False
self.__mark_repeatable_hooks__(
AdapterLifeSpan.BEFORE_EVENT_HANDLE, AdapterLifeSpan.BEFORE_ACTION_EXEC
)
[文档]
@final
def get_isrcs(self, filter: Callable[[InSourceT], bool]) -> set[InSourceT]:
"""获取与当前适配器匹配的所有输入源
:param filter: 过滤函数,为 `True` 时保留输入源
:return: 输入源的集合
"""
return set(src for src in self.in_srcs if filter(src))
[文档]
@final
def get_osrcs(self, filter: Callable[[OutSourceT], bool]) -> set[OutSourceT]:
"""获取与当前适配器匹配的所有输出源
:param filter: 过滤函数,为 `True` 时保留输出源
:return: 输出源的集合
"""
return set(src for src in self.out_srcs if filter(src))
@final
async def __adapter_input_loop__(self, src: InSourceT) -> NoReturn:
while True:
try:
packet = await src.input()
event: Event = await self._event_factory.create(packet)
EventOrigin.set_origin(event, EventOrigin(self, src))
await self._hook_bus.emit(AdapterLifeSpan.BEFORE_EVENT_HANDLE, True, args=(event,))
self.dispatcher.broadcast(event)
except Exception:
logger.generic_exc(
f"适配器 {self} 处理输入源 {src} 时发生异常",
obj={
"in_factory": self._event_factory,
"dispatcher": self.dispatcher,
}
| locals(),
)
@asynccontextmanager
@final
async def __wait_close_hook__(self) -> AsyncGenerator[None, None]:
try:
yield
finally:
await self._hook_bus.emit(AdapterLifeSpan.CLOSE, True)
@asynccontextmanager
@final
async def __adapter_launch__(self) -> AsyncGenerator[Self, None]:
if self._inited:
raise AdapterError(f"适配器 {self} 已在运行,不能重复启动")
try:
async with AsyncExitStack() as stack:
out_src_ts = tuple(
create_task(stack.enter_async_context(src)) for src in self.out_srcs
)
if len(out_src_ts):
await asyncio.wait(out_src_ts)
# 防止有些输入输出源在第一次启动失败后,又触发一次下面的启动
in_src_set = cast(set[AbstractSource], self.in_srcs) - cast(
set[AbstractSource], self.out_srcs
)
in_src_ts = tuple(
create_task(stack.enter_async_context(src))
for src in in_src_set
if not src.opened()
)
if len(in_src_ts):
await asyncio.wait(in_src_ts)
for src in self.in_srcs:
create_task(self.__adapter_input_loop__(src))
await stack.enter_async_context(self.__wait_close_hook__())
self._inited = True
await self._hook_bus.emit(AdapterLifeSpan.STARTED)
yield self
finally:
if self._inited:
await self._hook_bus.emit(AdapterLifeSpan.STOPPED, True)
[文档]
@final
def filter_out(
self, filter: Callable[[OutSourceT], bool]
) -> _GeneratorContextManager[Callable[[OutSourceT], bool]]:
"""上下文管理器,提供由 `filter` 控制输出的输出上下文
:param filter: 过滤函数,为 `True` 时允许该输出源输出
"""
return _OUT_SRC_FILTER_CTX.unfold(filter)
[文档]
async def call_output(self, action: ActionT) -> ActionHandleGroup:
"""输出行为,返回 :class:`.ActionHandleGroup`
适配器开发者的适配器子类可以重写此方法,以实现自定义功能
但子类实现中必须调用原始方法: `super().call_output(...)`
:param action: 行为
:return: :class:`.ActionHandleGroup` 对象
"""
osrcs: Iterable[OutSourceT]
filter = _OUT_SRC_FILTER_CTX.try_get()
cur_isrc: AbstractInSource | None
if event := _FLOW_CTX.try_get_event():
cur_isrc = EventOrigin.get_origin(event).in_src
else:
cur_isrc = None
if filter is not None:
osrcs = (osrc for osrc in self.out_srcs if filter(osrc))
elif isinstance(cur_isrc, AbstractOutSource) and cur_isrc in self.out_srcs:
osrcs = (cast(OutSourceT, cur_isrc),)
else:
osrcs = self.out_srcs if len(self.out_srcs) else ()
await self._hook_bus.emit(AdapterLifeSpan.BEFORE_ACTION_EXEC, True, args=(action,))
hs: tuple[ActionHandle, ...] = tuple(
ActionHandle(action, osrc, self._output_factory, self._echo_factory) for osrc in osrcs
)
return ActionHandleGroup(*hs)
@abstractmethod
async def __send_text__(self, *texts: str | TextContent) -> ActionHandleGroup:
"""输出文本
抽象方法。所有适配器子类应该实现此方法
:param texts: 文本
:return: :class:`.ActionHandleGroup` 对象
"""
raise NotImplementedError
async def __send_media__(
self,
name: str,
raw: bytes | None = None,
url: str | None = None,
mimetype: str | None = None,
) -> ActionHandleGroup:
"""输出多媒体内容
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 多媒体内容的名称
:param raw: 多媒体内容的二进制内容
:param url: 多媒体内容的网络地址(和 `raw` 参数二选一)
:param mimetype: 多媒体内容的 mimetype,为空则根据 `name` 自动检测
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(
f"[melobot media: {name if url is None else name + ' at ' + url}]"
)
async def __send_image__(
self,
name: str,
raw: bytes | None = None,
url: str | None = None,
mimetype: str | None = None,
) -> ActionHandleGroup:
"""输出图像内容
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 图像内容的名称
:param raw: 图像内容的二进制内容
:param url: 图像内容的网络地址(和 `raw` 参数二选一)
:param mimetype: 图像内容的 mimetype,为空则根据 `name` 自动检测
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(
f"[melobot image: {name if url is None else name + ' at ' + url}]"
)
async def __send_audio__(
self,
name: str,
raw: bytes | None = None,
url: str | None = None,
mimetype: str | None = None,
) -> ActionHandleGroup:
"""输出音频内容
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 音频内容的名称
:param raw: 音频内容的二进制内容
:param url: 音频内容的网络地址(和 `raw` 参数二选一)
:param mimetype: 音频内容的 mimetype,为空则根据 `name` 自动检测
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(
f"[melobot audio: {name if url is None else name + ' at ' + url}]"
)
async def __send_voice__(
self,
name: str,
raw: bytes | None = None,
url: str | None = None,
mimetype: str | None = None,
) -> ActionHandleGroup:
"""输出语音内容
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 语音内容的名称
:param raw: 语音内容的二进制内容
:param url: 语音内容的网络地址(和 `raw` 参数二选一)
:param mimetype: 语音内容的 mimetype,为空则根据 `name` 自动检测
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(
f"[melobot voice: {name if url is None else name + ' at ' + url}]"
)
async def __send_video__(
self,
name: str,
raw: bytes | None = None,
url: str | None = None,
mimetype: str | None = None,
) -> ActionHandleGroup:
"""输出视频内容
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 视频内容的名称
:param raw: 视频内容的二进制内容
:param url: 视频内容的网络地址(和 `raw` 参数二选一)
:param mimetype: 视频内容的 mimetype,为空则根据 `name` 自动检测
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(
f"[melobot video: {name if url is None else name + ' at ' + url}]"
)
async def __send_file__(self, name: str, path: str | PathLike[str]) -> ActionHandleGroup:
"""输出文件
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 文件名
:param path: 文件路径
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(f"[melobot file: {name} at {path}]")
async def __send_refer__(
self, event: Event, contents: Sequence[Content] | None = None
) -> ActionHandleGroup:
"""输出对过往事件的引用
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param event: 过往的事件对象
:param contents: 附加的通用内容序列
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(f"[melobot refer: {event.__class__.__name__}({event.id})]")
async def __send_resource__(self, name: str, url: str) -> ActionHandleGroup:
"""输出网络资源
建议所有适配器子类重写此方法,否则回退到基类实现:仅使用 :func:`send_text` 输出相关提示信息
:param name: 网络资源名称
:param url: 网络资源的 url
:return: :class:`.ActionHandleGroup` 对象
"""
return await self.__send_text__(f"[melobot resource: {name} at {url}]")
@property
def on_before_event_handle(
self,
) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]:
"""给适配器注册 :obj:`.AdapterLifeSpan.BEFORE_EVENT_HANDLE` 阶段 hook 的装饰器"""
return self.on(AdapterLifeSpan.BEFORE_EVENT_HANDLE)
@property
def on_before_action_exec(
self,
) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]:
"""给适配器注册 :obj:`.AdapterLifeSpan.BEFORE_ACTION_EXEC` 阶段 hook 的装饰器"""
return self.on(AdapterLifeSpan.BEFORE_ACTION_EXEC)
@property
def on_started(
self,
) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]:
"""给适配器注册 :obj:`.AdapterLifeSpan.STARTED` 阶段 hook 的装饰器"""
return self.on(AdapterLifeSpan.STARTED)
@property
def on_close(
self,
) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]:
"""给适配器注册 :obj:`.AdapterLifeSpan.CLOSE` 阶段 hook 的装饰器"""
return self.on(AdapterLifeSpan.CLOSE)
@property
def on_stopped(
self,
) -> Callable[[SyncOrAsyncCallable[P, None]], AsyncCallable[P, None]]:
"""给适配器注册 :obj:`.AdapterLifeSpan.STOPPED` 阶段 hook 的装饰器"""
return self.on(AdapterLifeSpan.STOPPED)
AdapterT = TypeVar("AdapterT", bound=Adapter)