# Source code of module: melobot.protocols.onebot.v11.io.ws (源代码)

from __future__ import annotations

import asyncio
import http
import json
import time
from asyncio import Future

from typing_extensions import TYPE_CHECKING, Any, Callable, Coroutine
from websockets.asyncio.server import ServerConnection
from websockets.http11 import Request, Response

from melobot._hook import HookBus
from melobot.io import SourceLifeSpan
from melobot.log import logger
from melobot.utils import get_id

from ..const import ACTION_TYPE_KEY_NAME
from .base import BaseIOSource, InstCounter
from .packet import (
    DownstreamCallInPacket,
    EchoPacket,
    EchoToDownstream,
    EventToDownstream,
    InPacket,
    OutPacket,
    ShareToDownstreamInPacket,
    UpstreamRetInPacket,
)
from .ws_impl import WSClientImpl, WSServerImpl

if TYPE_CHECKING:
    from .ws_rproxy import RProxyWSClient, RProxyWSServer


class GenericIOLayer:
    """Protocol-level I/O logic shared by the OneBot v11 websocket sources.

    This class is a mixin. The concrete websocket implementation it is combined
    with supplies ``name``, ``cd_time``, ``_hook_bus``, ``_start``, ``_stop``
    and ``_opened``. The layer:

    * parses frames received from the OneBot implementation into input packets;
    * queues outgoing packets and spaces sends at least ``cd_time`` seconds apart;
    * matches action responses to waiting callers through the ``echo`` field;
    * when a reverse proxy ``rproxy`` is bound, relays events and action results
      to downstream clients and forwards their calls upstream.
    """

    def __init__(self, rproxy: "RProxyWSClient | RProxyWSServer | None") -> None:
        """Create the buffers and bookkeeping tables.

        :param rproxy: optional reverse proxy. When given, it is bound to this
            source via ``rproxy.bind_src(self)``.
        """
        self._in_buf: asyncio.Queue[InPacket] = asyncio.Queue()
        self._out_buf: asyncio.Queue[OutPacket] = asyncio.Queue()
        self._pre_send_time = time.time_ns()
        # echo id -> (action type, future awaited by ``output``); used for calls
        # issued by melobot itself
        self._echo_table: dict[str, tuple[str, Future[EchoPacket]]] = {}
        # upstream-visible echo id -> (action type, echo id the downstream used)
        self._echo_mapping: dict[str, tuple[str, str]] = {}
        # Strong references to fire-and-forget relay tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage collected before it finishes.
        self._relay_tasks: set[asyncio.Task[None]] = set()

        self._rproxy = rproxy
        if self._rproxy is not None:
            self._rproxy.bind_src(self)

        # These attributes become available once this layer is combined with a
        # concrete implementation class.
        self.name: str
        self.cd_time: float
        self._hook_bus: HookBus
        self._start: Callable[[], Coroutine[Any, Any, None]]
        self._stop: Callable[[], Coroutine[Any, Any, None]]
        self._opened: asyncio.Event

    def _spawn_relay_task(self, coro: Coroutine[Any, Any, None]) -> None:
        """Run *coro* in the background, holding a reference until it is done."""
        task = asyncio.create_task(coro)
        self._relay_tasks.add(task)
        task.add_done_callback(self._relay_tasks.discard)

    async def _on_received(self, raw: str | bytes) -> None:
        """Handle one frame received from the OneBot implementation.

        A frame containing ``post_type`` is an event. It is queued for the
        adapter and, when a reverse proxy is bound, also shared downstream.

        Any other frame is treated as an action response and routed by its
        ``echo`` field:

        * to a downstream caller, if the echo is in ``_echo_mapping``;
        * otherwise to the pending ``output`` call registered in ``_echo_table``.

        Empty frames and responses without an echo are ignored. Responses whose
        echo is unknown, or whose waiter already finished, are dropped.
        """
        # Covers both "" and b"" (json.loads would raise on empty bytes).
        if not raw:
            return

        raw_dic = json.loads(raw)
        if "post_type" in raw_dic:
            if self._rproxy is None:
                self._in_buf.put_nowait(InPacket(time=raw_dic["time"], data=raw_dic))
            else:
                share_pak = ShareToDownstreamInPacket(time=raw_dic["time"], data=raw_dic)
                self._in_buf.put_nowait(share_pak)
                self._spawn_relay_task(self._to_downstream(share_pak.to_downstream))
            return

        echo_id = raw_dic.get("echo")
        if echo_id in (None, ""):
            return

        if self._rproxy is not None and echo_id in self._echo_mapping:
            # Build an "upstream return" event for a call initiated downstream.
            event_data: dict[str, Any] = {}
            event_data["post_type"] = "upstream_ret"
            event_data["time"] = int(time.time_ns() / 1e9)
            event_data["self_id"] = -1

            # Restore the echo value the downstream client originally sent.
            calling_type, down_seen_echo = self._echo_mapping.pop(echo_id)
            raw_dic[ACTION_TYPE_KEY_NAME] = calling_type
            raw_dic["echo"] = down_seen_echo
            event_data["ret"] = raw_dic

            ret_pak = UpstreamRetInPacket(time=event_data["time"], data=event_data)
            # Queue it so the adapter turns it into an upstream-return event.
            self._in_buf.put_nowait(ret_pak)
            self._spawn_relay_task(self._to_downstream(ret_pak.to_downstream))
            return

        # Result of a call issued by melobot itself.
        entry = self._echo_table.pop(echo_id, None)
        if entry is None:
            # Unknown echo, e.g. the response arrived after its caller gave up.
            logger.warning(f"{self.name} 收到无对应请求的 echo 响应,已丢弃: {echo_id!r}")
            return

        action_type, fut = entry
        if fut.done():
            # The waiter was already cancelled; setting a result would raise.
            return
        fut.set_result(
            EchoPacket(
                time=int(time.time_ns() / 1e9),
                data=raw_dic,
                ok=raw_dic["status"] == "ok",
                status=raw_dic["retcode"],
                action_type=action_type,
            )
        )

    async def _on_get_output(self) -> str | bytes | None:
        """Take the next outgoing packet and return its raw data.

        Before returning, waits until at least ``cd_time`` seconds have passed
        since the last send recorded by ``_on_sent``.
        """
        out_packet = await self._out_buf.get()
        wait_time = self.cd_time - ((time.time_ns() - self._pre_send_time) / 1e9)
        await asyncio.sleep(wait_time)
        return out_packet.data

    async def _on_sent(self, out: str | bytes) -> None:
        """Record the send timestamp used for the cool-down in ``_on_get_output``."""
        self._pre_send_time = time.time_ns()

    async def _on_relinked(self) -> None:
        """Emit the ``SourceLifeSpan.RESTARTED`` hook after a reconnection."""
        await self._hook_bus.emit(SourceLifeSpan.RESTARTED, False)

    async def input(self) -> InPacket:
        """Wait for and return the next buffered input packet."""
        return await self._in_buf.get()

    async def output(self, packet: OutPacket) -> EchoPacket:
        """Queue *packet* for sending and wait for its response if it has an echo.

        :param packet: the packet to send.
        :return: ``EchoPacket(noecho=True)`` when ``packet.echo_id`` is None;
            otherwise the response matched to the packet's echo id.
        :raises RuntimeError: when more than 100 packets are already queued,
            or more than 256 echo responses are pending. In both cases the
            packet is NOT queued.
        """
        if self._out_buf.qsize() > 100:
            logger.warning(
                f"{self.name} 输出缓冲区溢出,开始丢弃操作请求。请保证连接畅通或减少操作请求频率"
            )
            raise RuntimeError("输出缓冲区溢出,操作请求被丢弃")

        echo_id = packet.echo_id
        # Check before queueing. Otherwise a "dropped" request would still be
        # sent, and its response would find no matching echo entry.
        if echo_id is not None and len(self._echo_table) > 256:
            logger.warning(
                f"{self.name} echo 标识映射表溢出,开始丢弃操作请求。请保证连接畅通或减少操作请求频率"
            )
            raise RuntimeError("echo 标识映射表溢出,操作请求被丢弃")

        self._out_buf.put_nowait(packet)
        if echo_id is None:
            return EchoPacket(noecho=True)

        fut: Future[EchoPacket] = asyncio.get_running_loop().create_future()
        self._echo_table[echo_id] = (packet.action_type, fut)
        try:
            return await fut
        finally:
            # Normally ``_on_received`` has already removed the entry. If this
            # call was cancelled, remove it here so the table cannot leak.
            # Only our own future is removed.
            entry = self._echo_table.get(echo_id)
            if entry is not None and entry[1] is fut:
                del self._echo_table[echo_id]

    async def _to_downstream(
        self, fut: Future[EventToDownstream] | Future[EchoToDownstream]
    ) -> None:
        """Wait for the adapter's decision and forward the data to the reverse proxy.

        Nothing is sent if the result is forbidden or no proxy is bound.
        Exceptions are logged, not raised.
        """
        try:
            ret = await fut
            if not ret.is_forbidden() and self._rproxy is not None:
                self._rproxy.to_downstream(ret.get_json())
        except Exception:
            logger.generic_exc(f"{self.name} 传递数据给下游时发生异常", obj={"fut": fut})

    async def _to_upstream(self, raw: str | bytes) -> None:
        """Forward a call received from a downstream client to the implementation.

        The call is first offered to the adapter as a ``downstream_call``
        packet. If allowed, it is sent upstream under a fresh echo id, and the
        id the downstream client used is remembered in ``_echo_mapping``.
        Exceptions are logged, not raised.
        """
        try:
            if not raw:
                return
            raw_dic = json.loads(raw)
            raw_dic["post_type"] = "downstream_call"
            raw_dic["time"] = int(time.time_ns() / 1e9)
            raw_dic["self_id"] = -1
            pak = DownstreamCallInPacket(time=raw_dic["time"], data=raw_dic)

            # Queue it so the adapter turns it into a downstream-call event.
            await self._in_buf.put(pak)
            out = await pak.to_upstream
            if out.is_forbidden():
                return

            # Replace the echo and remember the mapping. The downstream's echo
            # could (rarely) collide with one melobot uses internally.
            down_seen_echo = out.echo
            up_seen_echo = get_id()
            out_data = json.dumps(
                {**out.get_dict(deepcopy=False), "echo": up_seen_echo}, ensure_ascii=False
            )
            out_pak = OutPacket(
                data=out_data,
                action_type=out.type,
                action_params=out.params,
                echo_id=up_seen_echo,
            )
            if len(self._echo_mapping) > 256:
                logger.warning(
                    f"{self.name} 反代 echo 标识映射表溢出,开始丢弃操作请求。请保证连接畅通或减少操作请求频率"
                )
                raise RuntimeError(f"{self.name} 反代 echo 标识映射表溢出,操作请求被丢弃")
            self._echo_mapping[up_seen_echo] = (out_pak.action_type, down_seen_echo)
            self._out_buf.put_nowait(out_pak)
        except Exception:
            logger.generic_exc(f"{self.name} 传递数据给上游时发生异常", obj={"raw": raw})

    async def open(self) -> None:
        """Start the connection, then open the reverse proxy if one is bound."""
        await self._start()
        if self._rproxy is not None:
            await self._rproxy.open()

    def opened(self) -> bool:
        """Return whether the ``_opened`` event is set."""
        return self._opened.is_set()

    async def close(self) -> None:
        """Close the reverse proxy (if bound) first, then stop the connection."""
        if self._rproxy is not None:
            await self._rproxy.close()
        await self._stop()


class WSClient(InstCounter, GenericIOLayer, WSClientImpl, BaseIOSource):
    """OneBot v11 source in which melobot connects as a websocket client to ``url``."""

    def __init__(
        self,
        url: str,
        max_retry: int = -1,
        retry_delay: float = 4.0,
        cd_time: float = 0,
        access_token: str | None = None,
        rproxy: "RProxyWSClient | RProxyWSServer | None" = None,
        *,
        name: str | None = None,
    ) -> None:
        """Build a websocket client source.

        :param url: websocket URL of the OneBot implementation.
        :param max_retry: passed to ``WSClientImpl``. Presumably -1 means
            "retry forever"; confirm in ``WSClientImpl``.
        :param retry_delay: passed to ``WSClientImpl`` as the retry delay.
        :param cd_time: minimum spacing, in seconds, between outgoing sends.
        :param access_token: when given, sent as an
            ``Authorization: Bearer <token>`` request header.
        :param rproxy: optional reverse proxy to bind to this source.
        :param name: display name; defaults to a numbered
            "OB11 WS 客户端 #N" label.
        """
        InstCounter.__init__(self)
        GenericIOLayer.__init__(self, rproxy=rproxy)

        # Default label uses the per-class counter set by InstCounter.__init__.
        client_name = f"OB11 WS 客户端 #{self.INSTANCE_COUNT}" if name is None else name
        auth_headers: dict[str, str] | None
        if access_token is None:
            auth_headers = None
        else:
            auth_headers = {"Authorization": f"Bearer {access_token}"}

        WSClientImpl.__init__(
            self,
            name=client_name,
            url=url,
            req_headers=auth_headers,
            max_retry=max_retry,
            retry_delay=retry_delay,
        )
        BaseIOSource.__init__(self, cd_time=cd_time)
        self._hook_bus.set_tag(self.name)
class WSServer(InstCounter, GenericIOLayer, WSServerImpl, BaseIOSource):
    """OneBot v11 source in which melobot listens on ``host:port`` as a websocket server.

    Only one client connection is accepted at a time. Further handshake
    requests are refused until the current connection is unlinked.
    """

    def __init__(
        self,
        host: str,
        port: int,
        cd_time: float = 0,
        access_token: str | None = None,
        rproxy: "RProxyWSClient | RProxyWSServer | None" = None,
        *,
        name: str | None = None,
    ) -> None:
        """Build a websocket server source.

        :param host: address to listen on.
        :param port: port to listen on.
        :param cd_time: minimum spacing, in seconds, between outgoing sends.
        :param access_token: when given, clients must send
            ``Authorization: Bearer <token>``.
        :param rproxy: optional reverse proxy to bind to this source.
        :param name: display name; defaults to a numbered
            "OB11 WS 服务端 #N" label.
        """
        InstCounter.__init__(self)
        GenericIOLayer.__init__(self, rproxy=rproxy)
        WSServerImpl.__init__(
            self,
            name=f"OB11 WS 服务端 #{self.INSTANCE_COUNT}" if name is None else name,
            host=host,
            port=port,
        )
        BaseIOSource.__init__(self, cd_time=cd_time)
        self.access_token = access_token
        self._hook_bus.set_tag(self.name)
        # Serialises the check-and-set of ``_conn_requested`` in ``_on_req``.
        self._req_lock = asyncio.Lock()
        # True once a client handshake has been accepted; reset on unlink.
        self._conn_requested = False

    def _check_access_token(self, headers: dict[str, str]) -> bool:
        """Return True if no token is configured or a header carries the right one.

        Uses a constant-time comparison, because the header values come from
        untrusted clients. Both values are encoded to bytes first:
        ``hmac.compare_digest`` rejects non-ASCII ``str``, and ``surrogatepass``
        keeps the encoding lossless for any code point.
        """
        if self.access_token is None:
            return True

        import hmac

        expected = f"Bearer {self.access_token}".encode("utf-8", "surrogatepass")
        accepted = False
        # Both spellings are checked in case the header mapping preserves case.
        for key in ("authorization", "Authorization"):
            value = headers.get(key)
            if value is not None and hmac.compare_digest(
                value.encode("utf-8", "surrogatepass"), expected
            ):
                accepted = True
        return accepted

    async def _on_req(self, conn: ServerConnection, req: Request) -> Response | None:
        """Accept or refuse an incoming websocket handshake request.

        :return: a 403 response when a connection is already accepted or the
            access token does not match; ``None`` to let the handshake proceed.
        """
        _headers = dict(req.headers)
        reconn_refused = "Already accepted the unique connection\n"
        auth_failed = "Authorization failed\n"

        # Fast path without the lock; repeated under the lock below.
        if self._conn_requested:
            return conn.respond(http.HTTPStatus.FORBIDDEN, reconn_refused)

        async with self._req_lock:
            if self._conn_requested:
                return conn.respond(http.HTTPStatus.FORBIDDEN, reconn_refused)
            if not self._check_access_token(_headers):
                logger.warning(f"{self.name} ws 客户端请求的 access_token 不匹配,拒绝连接")
                return conn.respond(http.HTTPStatus.FORBIDDEN, auth_failed)
            self._conn_requested = True
            return None

    async def _on_unlinked(self, ws: ServerConnection) -> None:
        """Allow a new client to connect after the current one disconnects."""
        self._conn_requested = False