# melobot.protocols.onebot.v11.io.ws_rproxy 源代码 (source code)

from __future__ import annotations

import asyncio
import http

from typing_extensions import Any, Callable, Coroutine
from websockets.asyncio.server import ServerConnection
from websockets.http11 import Request, Response

from melobot.log import logger

from .base import InstCounter
from .ws import GenericIOLayer
from .ws_impl import WSClientImpl, WSServerImpl


class GenericRProxyLayer:
    """Shared reverse-proxy plumbing mixed into the concrete WebSocket layers.

    Data received from the downstream peer is handed to the bound source IO
    layer (``io_src``). Data destined for the downstream peer is queued in
    ``to_downstream_buf`` until the transport pulls it via ``_on_get_output``.
    """

    # Once the output buffer holds more than this many items, further
    # ``to_downstream`` calls are rejected (and the data dropped) instead of
    # being queued. Subclasses may override it.
    _DOWNSTREAM_BUF_LIMIT: int = 100

    def __init__(self) -> None:
        self.io_src: GenericIOLayer | None = None
        self.to_downstream_buf: asyncio.Queue[str] = asyncio.Queue()
        # Strong references to in-flight forwarding tasks. The event loop only
        # keeps weak references to tasks, so a task created in ``_on_received``
        # could otherwise be garbage-collected before it finishes.
        self._upstream_tasks: set[asyncio.Task[Any]] = set()

        # These attributes are provided by the concrete implementation class
        # this layer is combined with (via multiple inheritance).
        self.name: str
        self._start: Callable[[], Coroutine[Any, Any, None]]
        self._stop: Callable[[], Coroutine[Any, Any, None]]
        # Tracks whether ``bind_src`` has already been called.
        self._bound = False

    def bind_src(self, src: GenericIOLayer) -> None:
        """Bind the source IO layer that received data is forwarded to.

        :param src: the source IO layer
        :raises RuntimeError: if a source has already been bound
        """
        if self._bound:
            raise RuntimeError(f"{self.name} 已经绑定了一个源对象,不能重复绑定")
        self.io_src = src
        self._bound = True

    async def _on_received(self, raw: str | bytes) -> None:
        """Forward data received from downstream to the bound source layer.

        The forwarding runs as a background task, so this coroutine returns
        without waiting for ``io_src._to_upstream`` to finish. If no source
        is bound, the data is dropped with a warning.
        """
        if self.io_src is None:
            logger.warning(f"{self.name} 没有绑定源对象,将丢弃收到的数据")
            return
        task = asyncio.create_task(self.io_src._to_upstream(raw))
        # Keep the task alive until it completes, then forget it.
        self._upstream_tasks.add(task)
        task.add_done_callback(self._upstream_tasks.discard)

    async def _on_get_output(self) -> str | bytes:
        """Wait for and return the next item queued for the downstream peer."""
        return await self.to_downstream_buf.get()

    async def open(self) -> None:
        """Start the layer through the implementation-provided ``_start``."""
        await self._start()

    async def close(self) -> None:
        """Stop the layer through the implementation-provided ``_stop``."""
        await self._stop()

    def to_downstream(self, raw: str) -> None:
        """Queue *raw* to be sent to the downstream peer.

        :param raw: data to send
        :raises RuntimeError: if the output buffer already holds more than
            ``_DOWNSTREAM_BUF_LIMIT`` items; *raw* is dropped in that case
        """
        if self.to_downstream_buf.qsize() > self._DOWNSTREAM_BUF_LIMIT:
            logger.warning(
                f"{self.name} 输出缓冲区溢出,开始丢弃发送到下游的数据。请保证连接畅通或减少数据发送频率"
            )
            raise RuntimeError("输出缓冲区溢出,发送到下游的数据被丢弃")
        self.to_downstream_buf.put_nowait(raw)


class RProxyWSClient(InstCounter, GenericRProxyLayer, WSClientImpl):
    """OneBot v11 reverse-proxy layer that acts as a WebSocket client.

    Combines the shared reverse-proxy plumbing (``GenericRProxyLayer``) with
    the WebSocket client implementation (``WSClientImpl``), which connects
    to ``url``.
    """

    def __init__(
        self,
        url: str,
        max_retry: int = -1,
        retry_delay: float = 4.0,
        access_token: str | None = None,
        *,
        name: str | None = None,
    ) -> None:
        """
        :param url: WebSocket URL to connect to
        :param max_retry: passed through to ``WSClientImpl``; presumably -1
            means retry without limit — confirm against ``WSClientImpl``
        :param retry_delay: passed through to ``WSClientImpl`` as the retry delay
        :param access_token: if given, sent as an
            ``Authorization: Bearer <access_token>`` request header
        :param name: layer name; defaults to a name numbered with
            ``INSTANCE_COUNT``
        """
        InstCounter.__init__(self)
        GenericRProxyLayer.__init__(self)
        WSClientImpl.__init__(
            self,
            name=f"OB11 反代/WS 客户端 #{self.INSTANCE_COUNT}" if name is None else name,
            url=url,
            req_headers=(
                None if access_token is None else {"Authorization": f"Bearer {access_token}"}
            ),
            max_retry=max_retry,
            retry_delay=retry_delay,
        )
class RProxyWSServer(InstCounter, GenericRProxyLayer, WSServerImpl):
    """OneBot v11 reverse-proxy layer that acts as a WebSocket server.

    Combines the shared reverse-proxy plumbing (``GenericRProxyLayer``) with
    the WebSocket server implementation (``WSServerImpl``). Only one client
    connection is accepted at a time. When ``access_token`` is set, clients
    must present it as a Bearer token.
    """

    def __init__(
        self, host: str, port: int, access_token: str | None = None, *, name: str | None = None
    ) -> None:
        """
        :param host: host to listen on
        :param port: port to listen on
        :param access_token: if given, required as
            ``Authorization: Bearer <access_token>`` on incoming requests
        :param name: layer name; defaults to a name numbered with
            ``INSTANCE_COUNT``
        """
        InstCounter.__init__(self)
        GenericRProxyLayer.__init__(self)
        WSServerImpl.__init__(
            self,
            name=f"OB11 反代/WS 服务端 #{self.INSTANCE_COUNT}" if name is None else name,
            host=host,
            port=port,
        )
        self.access_token = access_token
        # Serializes the check-and-set of ``_conn_requested`` in ``_on_req``.
        self._req_lock = asyncio.Lock()
        # True once a client request has been accepted; reset in ``_on_unlinked``.
        self._conn_requested = False

    async def _on_req(self, conn: ServerConnection, req: Request) -> Response | None:
        """Decide whether to accept an incoming WebSocket handshake request.

        :return: ``None`` to accept the request; otherwise a 403 response,
            either because a connection was already accepted or because the
            access token does not match
        """
        reconn_refused = "Already accepted the unique connection\n"
        auth_failed = "Authorization failed\n"

        # Cheap check without the lock; repeated under the lock below.
        if self._conn_requested:
            return conn.respond(http.HTTPStatus.FORBIDDEN, reconn_refused)

        async with self._req_lock:
            if self._conn_requested:
                return conn.respond(http.HTTPStatus.FORBIDDEN, reconn_refused)
            if self.access_token is not None:
                # ``Headers.get_all`` does a case-insensitive lookup and
                # tolerates repeated headers. Converting the headers with
                # ``dict()`` raises ``MultipleValuesError`` as soon as *any*
                # header appears more than once.
                expected = f"Bearer {self.access_token}"
                if expected not in req.headers.get_all("Authorization"):
                    logger.warning(f"{self.name} ws 客户端请求的 access_token 不匹配,拒绝连接")
                    return conn.respond(http.HTTPStatus.FORBIDDEN, auth_failed)
            self._conn_requested = True
            return None

    async def _on_unlinked(self, ws: ServerConnection) -> None:
        """Allow a new client to connect once the current one is unlinked."""
        self._conn_requested = False