melobot.mp 源代码

"""
此模块提供在 melobot 用例中对多进程的支持。
主要实现 spawn 模式子进程的自定义入口,以及避免默认序列化方式导致的级联加载。
基于对 spawn 模式子进程入口点的劫持,以及构建自定义可序列化对象。
只要不使用本模块的进程创建接口,就自动回退到 multiprocessing 原始逻辑
"""

# TODO: 由于存在侵入性设计,因此每个 py minor 版本都应该测试验证
import multiprocessing.spawn as spawn_mod
import multiprocessing.util as util_mod
import sys
from functools import wraps
from multiprocessing import current_process
from pathlib import Path

from typing_extensions import Any

from ._imp import ALL_EXTS

_PNAME_PREFIX = "MeloBot_MP"
MP_MODULE_NAME = "__mp_main__"
ROOT_MODULE_DIR = Path(__file__).parent.parent.resolve().as_posix()


def in_main_process() -> bool:
    """Return whether the current process is the main process.

    The check compares the current process's name with ``"MainProcess"``.
    """
    proc = current_process()
    return proc.name == "MainProcess"


def _wrapped_get_preparation_data(name: str) -> dict:
    """Wrap ``multiprocessing.spawn.get_preparation_data`` for melobot child processes.

    For a process registered in ``_P_STATUS`` (i.e. created through ``SpawnProcess``)
    whose name is accepted by ``SpawnProcess.owned``, the preparation data is rewritten
    so that the child:

    - has the entry file's directory prepended to ``sys.path``;
    - uses the recorded ``argv`` as ``sys.argv``;
    - uses the entry file's directory as both ``dir`` and ``orig_dir``;
    - loads its main module from the entry file path instead of by module name.

    Every other process receives the stock data unchanged. This includes one whose
    name merely happens to contain "melobot" but was not registered. Plain
    ``multiprocessing`` usage therefore falls back to the original behaviour instead
    of failing with ``KeyError``.

    :param name: name of the process being started
    :return: the (possibly rewritten) preparation data
    :raises RuntimeError: if the stock data asks the child to load ``__main__`` from a
        module name other than ``__main__`` / ``MP_MODULE_NAME``
    """
    data = _original_get_preparation_data(name)
    status = _P_STATUS.get(name)
    if status is None or not SpawnProcess.owned(name):
        return data

    data["sys_path"].insert(0, status["dir"])
    data["sys_argv"] = status["argv"]
    data["dir"] = data["orig_dir"] = status["dir"]

    # A private sentinel is used because the key may be absent. Only its presence
    # with an unexpected value is an error.
    missing = object()
    mod_name = data.pop("init_main_from_name", missing)
    if mod_name is not missing and mod_name not in ("__main__", MP_MODULE_NAME):
        raise RuntimeError(
            f"子进程中 __main__ 模块从名称 {mod_name!r} 加载,这种情况下无法安全生成子进程"
        )
    # Load the child's main module from the configured entry file instead.
    data["init_main_from_path"] = status["entry"]
    return data


def _wrapped_get_command_line(**kwargs: Any) -> Any:
    """Build the argument list used to launch a spawn-mode child interpreter.

    This mirrors ``multiprocessing.spawn.get_command_line``. In the non-frozen case,
    the ``-c`` program first inserts ``ROOT_MODULE_DIR`` at the front of ``sys.path``.
    It then imports ``melobot.mp``, which installs the hooks defined in this module,
    and only then calls ``spawn_main``.

    :param kwargs: keyword arguments forwarded (as ``name=repr(value)``) to ``spawn_main``
    :return: the command line for the child process
    """
    if getattr(sys, "frozen", False):
        return [sys.executable, "--multiprocessing-fork"] + [
            "%s=%r" % item for item in kwargs.items()
        ]

    spawn_args = ", ".join("%s=%r" % item for item in kwargs.items())
    # Compose the program with f-strings only. Applying %-formatting after
    # ROOT_MODULE_DIR has been interpolated would treat any '%' in that path as a
    # format specifier and break (or corrupt) the child command line.
    prog = (
        f"import sys; sys.path.insert(0, {ROOT_MODULE_DIR!r}); import melobot.mp; "
        f"from multiprocessing.spawn import spawn_main; spawn_main({spawn_args})"
    )
    opts = util_mod._args_from_interpreter_flags()  # type: ignore[attr-defined]
    return [spawn_mod.get_executable()] + opts + ["-c", prog, "--multiprocessing-fork"]


def _wrapped_prepare(data: Any) -> None:
    """Run the stock ``multiprocessing.spawn.prepare``, then adjust melobot children.

    When ``data["name"]`` is accepted by ``SpawnProcess.owned``, SIGINT and SIGTERM
    (plus SIGBREAK on Windows) are set to ``SIG_IGN`` in the child.

    :param data: the preparation data received by the child
    :return: whatever the original ``prepare`` returned
    """
    result = _original_prepare(data)
    if not SpawnProcess.owned(data["name"]):
        return result

    import signal

    # Ignore the common interrupt/terminate signals by default: the child process
    # is meant to be managed entirely by its parent.
    ignored: list[signal.Signals] = [signal.SIGINT, signal.SIGTERM]
    if sys.platform == "win32":
        ignored.append(signal.SIGBREAK)
    for signum in ignored:
        signal.signal(signum, signal.SIG_IGN)
    return result


# Install the spawn hooks at import time.
# For each hooked function, the original ``multiprocessing.spawn`` attribute is kept
# in a module global (the ``_wrapped_*`` functions above call it by that name). The
# module attribute is then replaced by the wrapper. ``wraps`` copies the original's
# metadata (``__name__``, ``__doc__``, ...) onto the wrapper.
# NOTE(review): re-executing this block (e.g. via importlib.reload) would capture
# the already-installed wrappers as the "originals", and the wrappers would then
# recurse into themselves -- confirm the module is never reloaded.
_original_get_preparation_data = spawn_mod.get_preparation_data
spawn_mod.get_preparation_data = wraps(_original_get_preparation_data)(
    _wrapped_get_preparation_data
)

_original_get_command_line = spawn_mod.get_command_line
spawn_mod.get_command_line = wraps(_original_get_command_line)(_wrapped_get_command_line)

_original_prepare = spawn_mod.prepare
spawn_mod.prepare = wraps(_original_prepare)(_wrapped_prepare)


import pickle
from concurrent.futures import ProcessPoolExecutor as _ProcessPoolExecutor
from functools import partial
from multiprocessing import get_context
from multiprocessing.context import SpawnContext as _SpawnContext
from multiprocessing.pool import Pool
from os import PathLike
from os.path import normpath
from pathlib import Path
from threading import RLock
from types import FunctionType, MethodType, ModuleType

from typing_extensions import Callable, Iterable, Mapping, TypeAlias, TypedDict, cast


# Launch settings recorded by SpawnProcess for one child process:
#   name  - the rewritten process name (also its key in _P_STATUS)
#   entry - normalized path of the child's entry file
#   argv  - the list used as the child's sys.argv
#   dir   - normalized directory containing the entry file
_ProcessStatus = TypedDict(
    "_ProcessStatus",
    {"name": str, "entry": str, "argv": list[str], "dir": str},
)


# Launch settings keyed by process name. Written by SpawnProcess.__init__ and read
# when the child's preparation data is built.
_P_STATUS: dict[str, _ProcessStatus] = {}


class SpawnProcess(get_context("spawn").Process):  # type: ignore[name-defined,misc]
    """melobot process class (child processes are created with the ``spawn`` start method)."""

    def __init__(
        self,
        entry: str | PathLike[str] | Path,
        argv: list[str] | None = None,
        target: Callable[..., object] | None = None,
        name: str | None = None,
        args: Iterable[Any] | None = None,
        kwargs: Mapping[str, Any] | None = None,
        *,
        daemon: bool | None = None,
    ) -> None:
        """Initialize a child process object.

        :param entry: entry module of the child process (must be an existing file with a
            loadable extension)
        :param argv: the child's ``argv``; when ``None`` it defaults to ``[<entry path>]``.
            The list is copied.
        :param target: target function of the child process
        :param name: identifying name of the child process. The real process name is not
            equal to this value; it only contains it
            (``f"{_PNAME_PREFIX}_{name}-{counter}"``).
        :param args: positional arguments for the target function
        :param kwargs: keyword arguments for the target function
        :param daemon: whether the process is a daemon process
        :raises RuntimeError: if called outside the main process
        :raises FileNotFoundError: if `entry` does not exist
        :raises ValueError: if `entry` is not a file with one of the loadable extensions
        """
        if not in_main_process():
            raise RuntimeError(
                "不应该在 melobot 管理的子进程中继续创建子进程。"
                "出现此异常可能是因为初始化参数“入口路径”设置错误,导致创建进程的代码在子进程中再次执行"
            )

        # Always let the base class generate its default "<ClassName>-<counter>" name,
        # so the counter can be extracted below. Passing the user's name here would
        # make it end up in place of the counter.
        super().__init__(
            None,
            target,
            None,
            () if args is None else args,
            {} if kwargs is None else kwargs,
            daemon=daemon,
        )
        order = self.name.split("-")[-1]
        self.init_name = name
        # Reset the name so the spawn hooks can recognise this process later.
        self.main_part = f"{id(self):x}" if self.init_name is None else self.init_name
        self.name: str = f"{_PNAME_PREFIX}_{self.main_part}-{order}"

        try:
            entry_file = Path(entry).resolve(True)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"子进程 {self.main_part} 的入口不存在: {entry!r}") from e
        if not entry_file.is_file() or not entry_file.as_posix().endswith(ALL_EXTS):
            raise ValueError(f"子进程 {self.main_part} 的入口不是可加载的文件: {entry_file!r}")

        # Apply normpath once more to stay compatible with the stock implementation.
        entry_norm_path = normpath(entry_file.as_posix())
        _P_STATUS[self.name] = {
            "name": self.name,
            "entry": entry_norm_path,
            # Copy so later mutation of the caller's list cannot change the child's argv.
            "argv": list(argv) if argv is not None else [entry_norm_path],
            "dir": normpath(entry_file.parent.as_posix()),
        }

    @staticmethod
    def owned(name: str) -> bool:
        """Return whether a process name marks it as melobot-managed.

        The check is a case-insensitive test for the substring ``"melobot"``.
        """
        return "melobot" in name.lower()
def in_melobot_sub_process() -> bool: """判断当前进程是否为 melobot 管理的子进程""" return SpawnProcess.owned(current_process().name) class _BanOriginalProcess: def __get__(self, *_: Any, **__: Any) -> None: raise AttributeError(f"内部尝试引用原始 Process 对象,请报告关于 {__name__} 模块的 bug") class SpawnContext(_SpawnContext): Process = cast(Any, _BanOriginalProcess()) def __init__(self, entry: str | PathLike[str] | Path, argv: list[str] | None = None) -> None: super().__init__() self.process_entry = entry self.process_argv = argv def __getattr__(self, name: str) -> Any: if name == "Process": return partial(SpawnProcess, self.process_entry, self.process_argv) raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
class SpawnProcessPool(Pool):
    """melobot process pool class (workers are created with the ``spawn`` start method)."""

    def __init__(
        self,
        entry: str | PathLike[str] | Path,
        argv: list[str] | None = None,
        processes: int | None = None,
        initializer: Callable[..., object] | None = None,
        initargs: Iterable[Any] | None = None,
        maxtasksperchild: int | None = None,
    ) -> None:
        """Initialize a process pool object.

        :param entry: entry module shared by all worker processes (must be a file)
        :param argv: ``argv`` of every worker process; ``None`` uses the default
        :param processes: number of worker processes
        :param initializer: initialization function
        :param initargs: arguments for the initialization function
        :param maxtasksperchild: number of tasks a worker runs before it is replaced
            by a fresh process
        """
        # Previously `init_args` was only bound when `initargs` was None, so passing
        # initargs raised UnboundLocalError. Materialize to a tuple so a one-shot
        # iterator is not exhausted by the first worker and the value can be pickled.
        init_args = () if initargs is None else tuple(initargs)
        super().__init__(
            processes, initializer, init_args, maxtasksperchild, SpawnContext(entry, argv)
        )
class SpawnProcessPoolExecutor(_ProcessPoolExecutor):
    """melobot process pool executor class (workers are created with the ``spawn`` start method)."""

    def __init__(
        self,
        entry: str | PathLike[str] | Path,
        argv: list[str] | None = None,
        max_workers: int | None = None,
        initializer: Callable[..., object] | None = None,
        initargs: Iterable[Any] | None = None,
    ) -> None:
        """Initialize a process pool executor object.

        :param entry: entry module shared by all worker processes (must be a file)
        :param argv: ``argv`` of every worker process; ``None`` uses the default
        :param max_workers: number of workers (processes)
        :param initializer: initialization function
        :param initargs: arguments for the initialization function
        """
        # Previously `init_args` was only bound when `initargs` was None, so passing
        # initargs raised UnboundLocalError. A tuple is safe to reuse and to pickle.
        init_args = () if initargs is None else tuple(initargs)
        super().__init__(max_workers, SpawnContext(entry, argv), initializer, init_args)
# Short public aliases for the spawn-based classes.
Process: TypeAlias = SpawnProcess
Context: TypeAlias = SpawnContext
ProcessPool: TypeAlias = SpawnProcessPool
ProcessPoolExecutor: TypeAlias = SpawnProcessPoolExecutor

# Sentinel meaning "no value supplied".
_EMPTY = object()


# Plain placeholder class; its instances serve as intermediate attribute nodes.
class _DUMMY_CLS:
    pass


# Note: the process that serializes and the one that deserializes each have their
# own instance of this reentrant lock -- it is not shared across processes.
_PICKLE_RLOCK = RLock()
class PBox:
    """Pickle wrapper that changes or specifies where a pickled object is loaded from.

    When pickled, the wrapped object (or object name) is recorded as belonging to
    ``self.module``. ``_deserialize`` imports that module before restoring the object.
    """

    def __init__(
        self,
        value: Any = _EMPTY,
        name: str | None = None,
        module: str | None = None,
        entry: str | PathLike[str] | Path | None = None,
    ) -> None:
        """Create a pickle wrapper.

        :param value: the object to pickle; mutually exclusive with `name`. The default
            is a sentinel meaning "no value".
        :param name: name of the object to pickle; mutually exclusive with `value`
        :param module: name of the module the object comes from. When ``None`` or
            ``""`` it maps to the process entry module (``MP_MODULE_NAME``).
        :param entry: file path of that module; when ``None`` the module is loaded by
            name only
        :raises ValueError: if both or neither of `value` / `name` are given, if `value`
            is a bound method, or if `module` does not match the tail of `entry`'s path
        :raises FileNotFoundError: if `entry` does not exist
        """
        if value is _EMPTY and name is None:
            raise ValueError("值参数和名称参数不能同时为空")
        if value is not _EMPTY and name is not None:
            raise ValueError("值参数和名称参数不能同时存在")
        if isinstance(value, MethodType):
            raise ValueError("类或实例的方法不支持 pickle,请尝试 pickle 整个类或实例而不是方法")

        # Serialized bytes (set by `_serialize`).
        self.value: bytes
        # Whether `value` holds pickled data. If False, `value` is the UTF-8 encoded
        # attribute name, and unpickling fetches that attribute from the module instead.
        self.has_value: bool
        # Name of the module loaded before unpickling.
        self.module = cast(str, MP_MODULE_NAME if module in ("", None) else module)
        # File path of that module; None means it is loaded by module name only.
        self.entry: str | None

        if entry is not None:
            try:
                abs_entry_path = Path(entry).resolve(True)
            except FileNotFoundError as e:
                raise FileNotFoundError(f"入口路径 {entry!r} 不存在") from e
            self.entry = abs_entry_path.as_posix()

            if self.module != MP_MODULE_NAME:
                mod_parts = tuple(self.module.split("."))
                if "" in mod_parts:
                    raise ValueError(f"模块名 {module!r} 有误或存在相对导入语义")
                # The dotted module name must match the tail of the absolute path
                # with its extension removed.
                path_parts = abs_entry_path.parts[:-1] + (abs_entry_path.stem,)
                if path_parts[-len(mod_parts) :] != mod_parts:
                    raise ValueError(
                        "模块名 split('.') 后的序列,不是路径绝对化并去除扩展名后 split('/') 得到序列的尾子序列"
                    )
        else:
            self.entry = None

        # Only consumed by __reduce__; they are not part of the pickled payload.
        self._serial_args = (value, name)
        self._orig_value = value

    @staticmethod
    def _resolve_name_owner(value: Any) -> tuple[str, Any]:
        """Find the object whose qualified name pickle will record for `value`.

        Objects with a non-empty ``__qualname__`` (functions, classes) own their name.
        Any other object is recorded through its class.

        :return: ``(qualname, name_owner)``
        :raises ValueError: if no usable qualified name or module name can be found
        """
        if getattr(value, "__qualname__", None) not in (None, ""):
            if getattr(value, "__module__", None) in (None, ""):
                raise ValueError(f"对象 {value} 所属模块无法找到,因此不能被 pickle")
            return value.__qualname__, value

        cls = getattr(value, "__class__", _EMPTY)
        if cls is _EMPTY:
            raise ValueError(f"对象 {value} 的类无法找到,因此不能被 pickle")
        qname = getattr(cls, "__qualname__", None)
        if qname in (None, ""):
            raise ValueError(f"对象 {value} 的类没有有效的名称,因此不能被 pickle")
        if getattr(cls, "__module__", None) in (None, ""):
            raise ValueError(f"对象 {value} 的类所属模块无法找到,因此不能被 pickle")
        return cast(str, qname), cls

    def _build_fake_module(self, qname: str, name_owner: Any) -> ModuleType:
        """Build a throw-away module named ``self.module`` that exposes `name_owner` at `qname`.

        Intermediate components of a dotted qualname are filled with ``_DUMMY_CLS``
        instances. Walking `qname` attribute by attribute from the module therefore
        ends at `name_owner`.
        """
        fake_mod = ModuleType(self.module)
        *parents, leaf = qname.split(".")
        node: Any = fake_mod
        for part in parents:
            child = _DUMMY_CLS()
            setattr(node, part, child)
            node = child
        setattr(node, leaf, name_owner)
        return fake_mod

    def _serialize(self, value: Any, name: str | None) -> None:
        """Fill ``self.value`` / ``self.has_value`` from the constructor arguments.

        Without a value, only the object's name is stored (UTF-8 encoded). With a value,
        the object is pickled inside a temporary environment built to fool pickle:

        - ``sys.modules[self.module]`` is replaced by a fake module;
        - the name owner's ``__qualname__``, ``__name__`` and ``__module__`` point at it.

        The resulting bytes therefore reference ``self.module`` instead of the real
        module. Everything is restored afterwards.
        """
        if value is _EMPTY:
            self.has_value = False
            self.value = cast(str, name).encode("utf-8")
            return

        with _PICKLE_RLOCK:
            self.has_value = True
            qname, name_owner = self._resolve_name_owner(value)
            fake_mod = self._build_fake_module(qname, name_owner)

            overrides = (
                ("__qualname__", qname),
                ("__name__", qname.split(".")[-1]),
                ("__module__", self.module),
            )
            # (attribute, previous value) pairs. An entry is recorded only after a
            # successful overwrite, so cleanup never re-touches an attribute that could
            # not be set (e.g. on immutable built-in types). This keeps the original
            # error from being masked.
            saved: list[tuple[str, Any]] = []
            # The sentinel distinguishes "absent" from an explicit None entry in sys.modules.
            orig_mod = sys.modules.get(self.module, _EMPTY)
            sys_modified = False
            try:
                for attr, new_value in overrides:
                    old_value = getattr(name_owner, attr, _EMPTY)
                    setattr(name_owner, attr, new_value)
                    saved.append((attr, old_value))
                sys.modules[self.module] = fake_mod
                sys_modified = True
                self.value = pickle.dumps(value)
            finally:
                if sys_modified:
                    if orig_mod is _EMPTY:
                        del sys.modules[self.module]
                    else:
                        sys.modules[self.module] = orig_mod
                # Restore in reverse order of overwriting.
                for attr, old_value in reversed(saved):
                    if old_value is not _EMPTY:
                        setattr(name_owner, attr, old_value)
                    elif hasattr(name_owner, attr):
                        delattr(name_owner, attr)

    def __reduce__(self) -> tuple[Callable, tuple[Any, ...]]:
        """Serialize now and hand pickle a call to ``_deserialize`` with the results.

        :raises pickle.PicklingError: wrapping any exception raised while serializing
        """
        try:
            self._serialize(*self._serial_args)
        except Exception as e:
            if self._orig_value is _EMPTY:
                raise pickle.PicklingError(f"Pickle 失败,{e}") from e
            raise pickle.PicklingError(f"Pickle 对象 {self._orig_value} 失败,{e}") from e
        return (_deserialize, (self.value, self.has_value, self.module, self.entry))
def _deserialize(value: bytes, has_value: bool, module: str, entry: str | None) -> Any:
    """Rebuild the object described by a pickled ``PBox``.

    The module `module` is imported first (via ``Importer.import_mod``, from the
    directory of `entry` when one is given). Then either `value` is unpickled, or,
    when `has_value` is False, the attribute named by `value` (UTF-8) is fetched
    from that module.

    :raises pickle.UnpicklingError: wrapping any exception raised while restoring
    """
    from ._imp import Importer

    with _PICKLE_RLOCK:
        try:
            if entry is None:
                dir_path = None
            else:
                dir_path = Path(entry).resolve(strict=True).parent.as_posix()
            loaded_mod = Importer.import_mod(module, dir_path)
            if not has_value:
                return getattr(loaded_mod, value.decode("utf-8"))
            return pickle.loads(value)
        except Exception as e:
            raise pickle.UnpicklingError(f"Unpickle 失败,{e}") from e