from __future__ import annotations
import asyncio
import inspect
from asyncio import Condition, Future, Lock
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from typing_extensions import Any, AsyncGenerator
from ..adapter.model import Event
from ..ctx import FlowCtx, SessionCtx
from ..exceptions import SessionRuleLacked, SessionStateFailed
from ..handle.base import EventCompletion, stop
from ..typ.base import SyncOrAsyncCallable
from .option import CompareInfo, Rule
# Module-level session context. ``Session.enter`` binds the active session to
# it via ``unfold()``, and the module-level ``suspend()`` reads the session
# back via ``get()``.
_SESSION_CTX = SessionCtx()
class SessionState:
    """Base class for the states a :class:`Session` can be in.

    Every transition is rejected here by raising :class:`SessionStateFailed`;
    a concrete state overrides only the transitions it supports.
    """

    def __init__(self, session: "Session") -> None:
        # The session this state object belongs to.
        self.session = session

    def _illegal(self, transition: str) -> SessionStateFailed:
        """Build the error for a transition the current state does not allow."""
        return SessionStateFailed(type(self).__name__, transition)

    async def work(self, completion: EventCompletion) -> None:
        """Start working on a new event (rejected by default)."""
        raise self._illegal(SessionState.work.__name__)

    async def rest(self) -> None:
        """Go idle (rejected by default)."""
        raise self._illegal(SessionState.rest.__name__)

    async def suspend(self, timeout: float | None) -> bool:
        """Suspend until woken up or timed out (rejected by default)."""
        raise self._illegal(SessionState.suspend.__name__)

    async def wakeup(self, completion: EventCompletion | None) -> None:
        """Wake up from suspension (rejected by default)."""
        raise self._illegal(SessionState.wakeup.__name__)

    async def expire(self) -> None:
        """Expire the session (rejected by default)."""
        raise self._illegal(SessionState.expire.__name__)
class SpareSessionState(SessionState):
    """Idle state: the only supported transition is :meth:`work`."""

    async def work(self, completion: EventCompletion) -> None:
        """Record ``completion``, adopt its event and switch to the working state."""
        owner = self.session
        owner._completions.add(completion)
        owner.event = completion.event
        owner.__to_state__(WorkingSessionState)
class WorkingSessionState(SessionState):
    """Working state: the session may go idle, suspend, or expire."""

    async def rest(self) -> None:
        """Switch to the idle state and notify one waiter on the refresh condition.

        :raises SessionRuleLacked: if the session has no rule
        """
        owner = self.session
        if owner.rule is None:
            raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“空闲态”")

        refresh = owner._refresh_cond
        owner.__to_state__(SpareSessionState)
        async with refresh:
            refresh.notify()

    async def suspend(self, timeout: float | None) -> bool:
        """Switch to the suspended state and wait on the wakeup condition.

        :param timeout: timeout handed to :func:`asyncio.wait_for`; ``None`` waits without a timeout
        :return: ``False`` when the wait timed out while the session was still
            not in the working state, ``True`` otherwise
        :raises SessionRuleLacked: if the session has no rule
        """
        owner = self.session
        # Auto-completion happens before the rule check, as a first step.
        owner.__try_auto_complete__()
        if owner.rule is None:
            raise SessionRuleLacked("缺少会话规则,会话无法从“运行态”转为“挂起态”")

        refresh = owner._refresh_cond
        owner.__to_state__(SuspendSessionState)
        async with refresh:
            refresh.notify()

        wakeup = owner._wakeup_cond
        async with wakeup:
            if timeout is None:
                await wakeup.wait()
                return True

            try:
                await asyncio.wait_for(wakeup.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                # If the session is already back in the working state, the
                # timeout is reported as a normal wakeup.
                timed_out = not owner.__is_state__(WorkingSessionState)
                if timed_out:
                    owner.__to_state__(WorkingSessionState)
                return not timed_out
            return True

    async def expire(self) -> None:
        """Switch to the expired state and mark all recorded events completed.

        One waiter on the refresh condition is notified only when the session
        has a rule.
        """
        owner = self.session
        owner.__to_state__(ExpireSessionState)
        if owner.rule is not None:
            refresh = owner._refresh_cond
            async with refresh:
                refresh.notify()
        owner.set_completed()
class SuspendSessionState(SessionState):
    """Suspended state: the only supported transition is :meth:`wakeup`."""

    async def wakeup(self, completion: EventCompletion | None) -> None:
        """Return the session to the working state and notify one suspended waiter.

        Does nothing when the session is already in the working state.

        :param completion: completion of the event that triggers the wakeup;
            when given, it is recorded and its event becomes the session's event
        """
        owner = self.session
        if owner.__is_state__(WorkingSessionState):
            return

        if completion is not None:
            owner._completions.add(completion)
            owner.event = completion.event

        wakeup_cond = owner._wakeup_cond
        owner.__to_state__(WorkingSessionState)
        async with wakeup_cond:
            wakeup_cond.notify()
class ExpireSessionState(SessionState):
    """Expired state: overrides nothing, so every transition is rejected."""
class SessionStore(dict[str, Any]):
    """Session store; its lifetime follows the session object."""

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        self.__setitem__(key, value)
class Session:
    """Session.

    :ivar SessionStore store: session store of the current session context
    :ivar Rule rule: session rule of the current session context
    """

    # Class-level registry: sessions grouped by rule, one lock per rule, and a
    # class lock that guards creation of the per-rule locks.
    __instances__: dict[Rule, set["Session"]] = {}
    __instance_locks__: dict[Rule, Lock] = {}
    __cls_lock__ = Lock()

    def __init__(
        self,
        rule: Rule | None,
        first_completion: EventCompletion,
        keep: bool = False,
        auto_complete: bool = True,
    ) -> None:
        """Create a session that starts in the working state.

        :param rule: session rule, or ``None`` for a rule-less session
        :param first_completion: completion of the event that created the session
        :param keep: whether the session is kept after leaving its context
        :param auto_complete: whether events are marked completed on suspend
        """
        self.store: SessionStore = SessionStore()
        self.event = first_completion.event
        self.rule = rule
        self.auto_complete = auto_complete

        # Completions of the events handled by this session so far.
        self._completions: set[EventCompletion] = set()
        self._completions.add(first_completion)
        # Notified when the session leaves the working state (rest/suspend/expire).
        self._refresh_cond = Condition()
        # Notified when a suspended session is woken up.
        self._wakeup_cond = Condition()
        self._keep = keep
        self._state: SessionState = WorkingSessionState(self)

    def stop_keep(self) -> None:
        """Stop keeping the session.

        If the session was entered with `keep=True`, call this method manually
        once the session no longer needs to be kept, to flag that it may be
        destroyed.
        """
        self._keep = False

    def set_completed(self, event: Event | None = None) -> None:
        """Mark events of this session as completed.

        If `event` is ``None``, every event in the session history is marked
        as completed.

        Only after an event has been marked completed may it propagate to lower
        priorities. Whether it actually does also depends on what the other
        handling flows do.

        For example, suppose an event triggers a batch of flows. If none of
        them enables a session, the event is marked "completed" automatically
        when the flows finish, and whether the event can propagate is evaluated
        once all flows are done.

        But if some of these flows enable a session, the session can be
        suspended, which means a very long handling period. The session
        therefore has to mark the event "completed" itself, so that the later
        propagation evaluation happens at the right time.

        When entering a session with :func:`enter_session` and
        `auto_complete=True`, the session marks its recorded events as
        completed each time it suspends, so an event is not "imprisoned" by the
        session and kept from propagating to the next priority. With `False`,
        this method has to be called manually to mark "completed".

        :param event: the event to mark, or ``None`` for all recorded events
        """
        if event is None:
            for c in self._completions:
                c.completed.set_result(None)
            self._completions.clear()
            return

        # Snapshot the matches before mutating: removing from the set while a
        # lazy iterator over it is still running raises
        # "RuntimeError: Set changed size during iteration".
        matched = [c for c in self._completions if c.event is event]
        for c in matched:
            c.completed.set_result(None)
            self._completions.remove(c)

    def get_incompletions(self) -> list[tuple[Event, Future]]:
        """Return all events in the session history that are not completed yet.

        :return: list of ``(event, completion future)`` tuples
        """
        return [(c.event, c.completed) for c in self._completions if not c.completed.done()]

    def __try_auto_complete__(self) -> None:
        """Mark all recorded events completed when ``auto_complete`` is enabled."""
        if self.auto_complete:
            self.set_completed()

    def __to_state__(self, state_class: type[SessionState]) -> None:
        """Replace the current state with a new instance of ``state_class``."""
        self._state = state_class(self)

    def __is_state__(self, state_class: type[SessionState]) -> bool:
        """Tell whether the current state is an instance of ``state_class``."""
        return isinstance(self._state, state_class)

    # The following methods delegate to the current state object, which decides
    # whether the transition is allowed.

    async def __work__(self, completion: EventCompletion) -> None:
        await self._state.work(completion)

    async def __rest__(self) -> None:
        await self._state.rest()

    async def __suspend__(self, timeout: float | None = None) -> bool:
        return await self._state.suspend(timeout)

    async def __wakeup__(self, completion: EventCompletion | None) -> None:
        await self._state.wakeup(completion)

    async def __expire__(self) -> None:
        await self._state.expire()

    @classmethod
    async def get(
        cls,
        completion: EventCompletion,
        rule: Rule | None = None,
        wait: bool = True,
        nowait_cb: SyncOrAsyncCallable[[], None] | None = None,
        keep: bool = False,
        auto_complete: bool = True,
    ) -> Session | None:
        """Find or create the session that should handle ``completion``'s event.

        Without a rule a fresh, unregistered session is returned. With a rule,
        the registered sessions of that rule are checked with
        ``rule.compare_with`` in this order:

        1. suspended sessions: the first match is woken up, ``None`` is returned;
        2. idle sessions: the first match is put to work and returned;
        3. working sessions: on a match either give up (``wait=False``) or wait
           until the session leaves the working state and then act on its new
           state;
        4. otherwise a new session is created, registered and returned.

        Expired sessions of the rule are removed from the registry on the way out.

        :param completion: completion of the incoming event
        :param rule: session rule
        :param wait: whether to wait when a matching session is busy
        :param nowait_cb: callback (sync or async) run on conflict when ``wait=False``
        :param keep: whether the session is kept after leaving its context
        :param auto_complete: whether events are marked completed on suspend
        :return: the session to run in, or ``None`` when the event was handed
            to an existing session or dropped because of a conflict
        """
        event = completion.event
        if rule is None:
            return Session(
                rule=None,
                first_completion=completion,
                keep=keep,
                auto_complete=auto_complete,
            )

        async with cls.__cls_lock__:
            cls.__instance_locks__.setdefault(rule, Lock())

        async with cls.__instance_locks__[rule]:
            try:
                _set = cls.__instances__.setdefault(rule, set())

                suspends = filter(lambda s: s.__is_state__(SuspendSessionState), _set)
                for session in suspends:
                    if await rule.compare_with(CompareInfo(session, session.event, event)):
                        await session.__wakeup__(completion)
                        return None

                spares = filter(lambda s: s.__is_state__(SpareSessionState), _set)
                for session in spares:
                    if await rule.compare_with(CompareInfo(session, session.event, event)):
                        await session.__work__(completion)
                        session._keep = keep
                        return session

                workings = filter(lambda s: s.__is_state__(WorkingSessionState), _set)
                for session in workings:
                    if not await rule.compare_with(CompareInfo(session, session.event, event)):
                        continue

                    if not wait:
                        if nowait_cb is not None:
                            ret = nowait_cb()
                            if inspect.isawaitable(ret):
                                await ret
                        # The event will not be handled by any session here, so
                        # mark it completed right away.
                        completion.completed.set_result(None)
                        return None

                    # Wait until the busy session rests, suspends or expires.
                    cond = session._refresh_cond
                    async with cond:
                        await cond.wait()

                    if session.__is_state__(ExpireSessionState):
                        # Expired in the meantime: keep looking at the others.
                        pass
                    elif session.__is_state__(SuspendSessionState):
                        await session.__wakeup__(completion)
                        return None
                    else:
                        await session.__work__(completion)
                        session._keep = keep
                        return session

                session = Session(
                    rule=rule,
                    first_completion=completion,
                    keep=keep,
                    auto_complete=auto_complete,
                )
                Session.__instances__[rule].add(session)
                return session

            finally:
                # Snapshot with tuple() so the set can be mutated safely.
                expires = tuple(filter(lambda s: s.__is_state__(ExpireSessionState), _set))
                for session in expires:
                    Session.__instances__[rule].remove(session)

    @classmethod
    @asynccontextmanager
    async def enter(
        cls,
        rule: Rule,
        wait: bool = True,
        nowait_cb: SyncOrAsyncCallable[[], None] | None = None,
        keep: bool = False,
        auto_complete: bool = True,
    ) -> AsyncGenerator[Session, None]:
        """Async context manager that runs its body inside a session.

        The completion of the current flow is fetched from a ``FlowCtx``,
        flagged as ``under_session`` and passed to :meth:`get`. On exit the
        session is put to rest when it is still kept, otherwise it is expired.

        :param rule: session rule
        :param wait: whether to wait when a matching session is busy
        :param nowait_cb: callback run on conflict when ``wait=False``
        :param keep: whether the session is kept after leaving the context
        :param auto_complete: whether events are marked completed on suspend
        :yield: the session object
        """
        flow_ctx = FlowCtx()
        completion = flow_ctx.get_completion()
        completion.under_session = True

        session = await cls.get(
            completion,
            rule=rule,
            wait=wait,
            nowait_cb=nowait_cb,
            keep=keep,
            auto_complete=auto_complete,
        )
        if session is None:
            # No session to run in; presumably stop() aborts the current flow
            # by raising - verify against its definition.
            await stop()

        with _SESSION_CTX.unfold(session):
            try:
                yield session
            except asyncio.CancelledError:
                # A session cancelled while suspended must return to the
                # working state first, because only that state implements
                # rest()/expire() used in the finally clause below.
                # NOTE(review): the CancelledError is not re-raised, so the
                # cancellation is suppressed at this context-manager boundary -
                # confirm this is intended.
                if session.__is_state__(SuspendSessionState):
                    await session.__wakeup__(completion=None)
            finally:
                if session._keep:
                    await session.__rest__()
                else:
                    await session.__expire__()
async def suspend(timeout: float | None = None) -> bool:
    """Suspend the current session.

    :param timeout: timeout for being woken up again after suspending; ``None`` means no timeout
    :return: ``False`` indicates that the wakeup timed out
    """
    current = _SESSION_CTX.get()
    woken = await current.__suspend__(timeout)
    return woken
def enter_session(
    rule: Rule,
    wait: bool = True,
    nowait_cb: SyncOrAsyncCallable[[], None] | None = None,
    keep: bool = False,
    auto_complete: bool = True,
) -> _AsyncGeneratorContextManager[Session]:
    """Context manager providing a session context for the session features.

    :param rule: session rule
    :param wait: whether to wait when a session conflict occurs
    :param nowait_cb: callback executed on a session conflict when `wait=False`
    :param keep: whether the session is kept after leaving the session context
    :param auto_complete: whether events are marked "completed" automatically
        when the session suspends; see :meth:`Session.set_completed` for details
    :yield: the session object
    """
    return Session.enter(
        rule=rule,
        wait=wait,
        nowait_cb=nowait_cb,
        keep=keep,
        auto_complete=auto_complete,
    )