# melobot.protocols.onebot.v11.io.forward — source code (源代码)
import asyncio
import json
import time
from asyncio import Future, Lock
from itertools import count
import websockets
from websockets.asyncio.client import ClientConnection
from websockets.exceptions import ConnectionClosed
from melobot.exceptions import SourceError
from melobot.io import SourceLifeSpan
from melobot.log import LogLevel
from .base import BaseIOSource
from .packet import EchoPacket, InPacket, OutPacket
# [文档] ("[docs]" link marker from the rendered documentation page)
class ForwardWebSocketIO(BaseIOSource):
    """OneBot v11 forward WebSocket IO source.

    Acts as a WebSocket *client*: it connects to the OneBot implementation at
    ``url``, reads event reports and action echoes from that connection, and
    sends outgoing action packets over it. If the connection drops while the
    source is open, the source closes itself and schedules a reconnect.
    """

    def __init__(
        self,
        url: str,
        max_retry: int = -1,
        retry_delay: float = 4.0,
        cd_time: float = 0,
        access_token: str | None = None,
    ) -> None:
        """Create the source; no connection is made until :meth:`open`.

        :param url: WebSocket URL of the OneBot implementation.
        :param max_retry: number of extra connection attempts after the first
            one fails; a negative value retries forever.
        :param retry_delay: seconds to wait after a failed connection attempt;
            non-positive values fall back to 0.5.
        :param cd_time: minimum interval in seconds between two sends
            (passed on to ``BaseIOSource``).
        :param access_token: if given, sent as an ``Authorization: Bearer``
            header when connecting.
        """
        super().__init__(cd_time)
        self.url = url
        self.conn: ClientConnection
        self.access_token = access_token
        self.max_retry: int = max_retry
        self.retry_delay: float = retry_delay if retry_delay > 0 else 0.5

        self._in_buf: asyncio.Queue[InPacket] = asyncio.Queue()
        self._out_buf: asyncio.Queue[OutPacket] = asyncio.Queue()
        self._pre_send_time = time.time_ns()
        # echo id -> (action type, future resolved by the input loop)
        self._echo_table: dict[str, tuple[str, Future[EchoPacket]]] = {}
        # Input/output loop tasks belonging to the current connection.
        self._tasks: list[asyncio.Task] = []
        # Fire-and-forget tasks (self-close / reconnect). Strong references are
        # kept here so they cannot be garbage collected mid-execution.
        self._bg_tasks: set[asyncio.Task] = set()
        self._opened = asyncio.Event()
        self._lock = Lock()
        # Set on an unexpected disconnect: close() then reopens, and open()
        # emits the RESTARTED lifespan hook.
        self._restart_flag = asyncio.Event()

    def _spawn(self, coro) -> None:
        """Run ``coro`` as a background task, holding a strong reference to it
        until it finishes."""
        task = asyncio.create_task(coro)
        self._bg_tasks.add(task)
        task.add_done_callback(self._bg_tasks.discard)

    async def _input_loop(self) -> None:
        """Receive frames and dispatch them until cancelled or disconnected.

        Frames containing ``post_type`` are queued as :class:`InPacket`.
        Frames with a non-empty ``echo`` resolve the future registered by
        :meth:`output`. Other exceptions are logged and the loop continues.
        """
        while True:
            try:
                await self._opened.wait()
                raw_str = await self.conn.recv()
                self.logger.generic_obj("收到上报,未格式化的字符串", raw_str, level=LogLevel.DEBUG)
                if raw_str == "":
                    continue

                raw = json.loads(raw_str)
                if "post_type" in raw:
                    await self._in_buf.put(InPacket(time=raw["time"], data=raw))
                    continue

                echo_id = raw.get("echo")
                if echo_id in (None, ""):
                    continue
                # The echo may be unknown, or stale because its waiter was
                # cancelled and removed itself. Ignore it rather than
                # raising KeyError.
                entry = self._echo_table.pop(echo_id, None)
                if entry is None:
                    continue
                action_type, fut = entry
                # A cancelled waiter leaves a done future behind, and
                # set_result() on it would raise InvalidStateError.
                if fut.done():
                    continue
                fut.set_result(
                    EchoPacket(
                        time=int(time.time()),
                        data=raw,
                        ok=raw["status"] == "ok",
                        status=raw["retcode"],
                        action_type=action_type,
                    )
                )
            except asyncio.CancelledError:
                break
            except ConnectionClosed:
                # Only an unexpected disconnect (source still marked open)
                # triggers a restart. A deliberate close() clears _opened
                # before closing the connection.
                if self.opened():
                    self._restart_flag.set()
                    self._spawn(self.close())
                break
            except Exception:
                self.logger.exception("OneBot v11 正向 WebSocket IO 源输入异常")
                self.logger.generic_obj("异常点局部变量", locals(), level=LogLevel.ERROR)

    async def _output_loop(self) -> None:
        """Send queued packets, keeping at least ``cd_time`` seconds between
        consecutive sends; exceptions other than cancellation are logged."""
        while True:
            try:
                await self._opened.wait()
                out_packet = await self._out_buf.get()
                wait_time = self.cd_time - ((time.time_ns() - self._pre_send_time) / 1e9)
                # A non-positive delay just yields to the event loop once.
                await asyncio.sleep(wait_time)
                await self.conn.send(out_packet.data)
                self._pre_send_time = time.time_ns()
            except asyncio.CancelledError:
                break
            except Exception:
                self.logger.exception("OneBot v11 正向 WebSocket IO 源输出异常")
                self.logger.generic_obj("异常点局部变量", locals(), level=LogLevel.ERROR)

    async def open(self) -> None:
        """Connect to the implementation and start the IO loops.

        This is a no-op if the source is already open. A failed connection
        attempt is logged and followed by a ``retry_delay`` second sleep. Up to
        ``max_retry + 1`` attempts are made (unlimited if ``max_retry`` is
        negative). When this open is a reconnect after an unexpected
        disconnect, ``SourceLifeSpan.RESTARTED`` is emitted.

        :raises SourceError: if every connection attempt failed.
        """
        if self.opened():
            return
        async with self._lock:
            if self.opened():
                return

            headers: dict | None = None
            if self.access_token is not None:
                headers = {"Authorization": f"Bearer {self.access_token}"}

            retry_iter = count(0) if self.max_retry < 0 else range(self.max_retry + 1)
            for _ in retry_iter:
                try:
                    self.conn = await websockets.connect(self.url, additional_headers=headers)
                    break
                # Exception (not BaseException): CancelledError, KeyboardInterrupt
                # and SystemExit must propagate instead of feeding a possibly
                # endless retry loop.
                except Exception as e:
                    self.logger.warning(f"连接建立失败,{self.retry_delay}s 后自动重试。错误:{e}")
                    if "403" in str(e):
                        self.logger.warning("403 错误可能是 access_token 未配置或无效")
                    await asyncio.sleep(self.retry_delay)
            else:
                raise SourceError("OneBot v11 正向 WebSocket IO 源重试已达最大次数,已放弃建立连接")

            self._tasks.append(asyncio.create_task(self._input_loop()))
            self._tasks.append(asyncio.create_task(self._output_loop()))
            self._opened.set()
            self.logger.info("OneBot v11 正向 WebSocket IO 源与实现端建立了连接")

            if self._restart_flag.is_set():
                self._restart_flag.clear()
                await self._hook_bus.emit(SourceLifeSpan.RESTARTED, False)

    def opened(self) -> bool:
        """Return whether the source currently holds an open connection."""
        return self._opened.is_set()

    async def close(self) -> None:
        """Close the connection and stop the IO loops; no-op if not open.

        If the close was triggered by an unexpected disconnect (restart flag
        set), a reconnect via :meth:`open` is scheduled afterwards.
        """
        if not self.opened():
            return
        async with self._lock:
            if not self.opened():
                return

            self.conn.close_timeout = 2
            # Clear first so the input loop treats the resulting
            # ConnectionClosed as deliberate and does not schedule a restart.
            self._opened.clear()
            await self.conn.close()
            await self.conn.wait_closed()

            for t in self._tasks:
                t.cancel()
            if self._tasks:
                await asyncio.wait(self._tasks)
            self._tasks.clear()
            self.logger.info("OneBot v11 正向 WebSocket IO 源已断开连接")

            if self._restart_flag.is_set():
                self._spawn(self.open())

    async def input(self) -> InPacket:
        """Wait for and return the next queued event report."""
        return await self._in_buf.get()

    async def output(self, packet: OutPacket) -> EchoPacket:
        """Queue ``packet`` for sending and wait for its echo if it expects one.

        :returns: ``EchoPacket(noecho=True)`` right after queueing when
            ``packet.echo_id`` is None. Otherwise, the echo delivered by the
            input loop for that id.
        """
        if packet.echo_id is None:
            await self._out_buf.put(packet)
            return EchoPacket(noecho=True)

        fut: Future[EchoPacket] = asyncio.get_running_loop().create_future()
        # Register before queueing, so the echo can never arrive before its
        # waiter exists.
        self._echo_table[packet.echo_id] = (packet.action_type, fut)
        try:
            await self._out_buf.put(packet)
            return await fut
        finally:
            # Normally the input loop has already removed the entry. This
            # cleans up when the caller is cancelled before the echo arrives.
            entry = self._echo_table.get(packet.echo_id)
            if entry is not None and entry[1] is fut:
                del self._echo_table[packet.echo_id]