From 64e3c7d21e8d6cb69a077c31fcf68c564473969f Mon Sep 17 00:00:00 2001 From: Filip Cacky Date: Tue, 24 Sep 2024 23:59:43 +0200 Subject: [PATCH] Add SignalManager --- metaflow/__init__.py | 1 + metaflow/runner/metaflow_runner.py | 4 +- metaflow/runner/signal_manager.py | 128 ++++++++++++++++++++++++++ metaflow/runner/subprocess_manager.py | 19 ++-- 4 files changed, 143 insertions(+), 9 deletions(-) create mode 100644 metaflow/runner/signal_manager.py diff --git a/metaflow/__init__.py b/metaflow/__init__.py index c901e81c38..9df9e18ad6 100644 --- a/metaflow/__init__.py +++ b/metaflow/__init__.py @@ -147,6 +147,7 @@ class and related decorators. # Runner API if sys.version_info >= (3, 7): from .runner.metaflow_runner import Runner + from .runner.signal_manager import SignalManager from .runner.nbrun import NBRunner from .runner.deployer import Deployer from .runner.nbdeploy import NBDeployer diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 450cae336c..fb1b44d303 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -10,6 +10,7 @@ from .utils import handle_timeout, async_handle_timeout, clear_and_set_os_environ from .subprocess_manager import CommandManager, SubprocessManager +from .signal_manager import SignalManager class ExecutingRun(object): @@ -231,6 +232,7 @@ def __init__( env: Optional[Dict] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, + signal_manager: Optional[SignalManager] = None, **kwargs ): # these imports are required here and not at the top @@ -257,7 +259,7 @@ def __init__( self.cwd = cwd self.file_read_timeout = file_read_timeout - self.spm = SubprocessManager() + self.spm = SubprocessManager(signal_manager=signal_manager) self.top_level_kwargs = kwargs self.api = MetaflowAPI.from_cli(self.flow_file, start) diff --git a/metaflow/runner/signal_manager.py b/metaflow/runner/signal_manager.py new file mode 100644 index 0000000000..8f177e5260 --- /dev/null +++ b/metaflow/runner/signal_manager.py @@ -0,0 +1,128 @@ +import asyncio +import signal +from typing import NewType, Mapping, Set, Callable, Optional + +SignalHandler = NewType("SignalHandler", Callable[[int, []], None]) + + +class SignalManager: + """ + A context manager for managing signal handlers. + + This class works as a context manager, restoring any overwritten + signal handlers when the context is exited. This only works for signals + in a synchronous context (ie. hooked by `signal`). + + Parameters + ---------- + hook_signals : bool + If True, the signal manager will overwrite any existing signal handlers + in either `asyncio` or `signal`. If you already have any signal + handling in place, you can set this to False and use `trigger_signal` + to trigger metaflow-related signal handlers. + event_loop : Optional[asyncio.AbstractEventLoop] + The event loop to use for handling signals. + If None, the current running event loop is used, if any. + """ + + hook_signals: bool + event_loop: Optional[asyncio.AbstractEventLoop] + signal_map: Mapping[int, Set[SignalHandler]] = dict() + replaced_signals: Mapping[int, SignalHandler] = dict() + + def __init__( + self, + hook_signals: bool = True, + event_loop: Optional[asyncio.AbstractEventLoop] = None, + ): + self.hook_signals = hook_signals + try: + self.event_loop = event_loop or asyncio.get_running_loop() + except RuntimeError: + self.event_loop = None + + def __exit__(self, exc_type, exc_value, traceback): + for sig in self.signal_map: + self._maybe_remove_signal_handler(sig) + + for sig in self.replaced_signals: + signal.signal(sig, self.replaced_signals[sig]) + + def _handle_signal(self, signum, frame): + for handler in self.signal_map[signum]: + handler(signum, frame) + + def _maybe_add_signal_handler(self, sig): + if not self.hook_signals: + return + + if self.event_loop is None: + replaced = signal.signal(sig, self._handle_signal) + self.replaced_signals[sig] = replaced + + else: + self.event_loop.add_signal_handler( + sig, lambda: self._handle_signal(sig, None) + ) + + def _maybe_remove_signal_handler(self, sig: int): + if not self.hook_signals: + return + + if self.event_loop is None: + signal.signal(sig, self.replaced_signals[sig]) + del self.replaced_signals[sig] + else: + self.event_loop.remove_signal_handler(sig) + + def add_signal_handler(self, sig: int, handler: SignalHandler): + """ + Add a signal handler for the given signal. + + Parameters + ---------- + sig: int + The signal to handle. + handler: SignalHandler + The handler to call when the signal is received. + """ + if sig not in self.signal_map: + self.signal_map[sig] = set() + self._maybe_add_signal_handler(sig) + + self.signal_map[sig].add(handler) + + def remove_signal_handler(self, sig: signal.Signals, handler: SignalHandler): + """ + Remove a signal handler for the given signal. + + Parameters + ---------- + sig: int + The signal to handle. + handler: SignalHandler + The handler to remove. + + Raises + ------ + KeyError + If the signal `sig` is not being handled. + """ + if sig not in self.signal_map: + return + + self.signal_map[sig].discard(handler) + + def trigger_signal(self, sig: int, frame=None): + """ + Trigger a signal handler for the given signal. + + Parameters + ---------- + sig : int + The signal to handle. + frame : [] (optional) + The frame to pass to the signal handler. + Only used in a synchronous context. + """ + self._handle_signal(sig, frame) diff --git a/metaflow/runner/subprocess_manager.py b/metaflow/runner/subprocess_manager.py index ad24c7ca8c..705dc6cfd1 100644 --- a/metaflow/runner/subprocess_manager.py +++ b/metaflow/runner/subprocess_manager.py @@ -9,6 +9,8 @@ import threading from typing import Callable, Dict, Iterator, List, Optional, Tuple +from .signal_manager import SignalManager + def kill_process_and_descendants(pid, termination_timeout): # TODO: there's a race condition that new descendants might @@ -73,17 +75,17 @@ class SubprocessManager(object): CommandManager objects, each of which manages an individual subprocess. """ - def __init__(self): + def __init__(self, signal_manager: SignalManager): self.commands: Dict[int, CommandManager] = {} + self.signal_manager = signal_manager or SignalManager() - try: - loop = asyncio.get_running_loop() - loop.add_signal_handler( + if self.signal_manager.event_loop is not None: + self.signal_manager.add_signal_handler( signal.SIGINT, - lambda: asyncio.create_task(self._async_handle_sigint()), + lambda s, f: asyncio.create_task(self._async_handle_sigint()), ) - except RuntimeError: - signal.signal(signal.SIGINT, self._handle_sigint) + else: + self.signal_manager.add_signal_handler(signal.SIGINT, self._handle_sigint) async def _async_handle_sigint(self): pids = [ @@ -193,7 +195,8 @@ def get(self, pid: int) -> Optional["CommandManager"]: return self.commands.get(pid, None) def cleanup(self) -> None: - """Clean up log files for all running subprocesses.""" + """Clean up signal handler and log files for all running subprocesses.""" + self.signal_manager.remove_signal_handler(signal.SIGINT, self.signal_handler) for v in self.commands.values(): v.cleanup()