diff --git a/metaflow/runner/subprocess_manager.py b/metaflow/runner/subprocess_manager.py
index 12468ae450..5cd7873057 100644
--- a/metaflow/runner/subprocess_manager.py
+++ b/metaflow/runner/subprocess_manager.py
@@ -10,23 +10,67 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 
-def kill_process_and_descendants(pid, termination_timeout):
+def send_signals(pid, signal):
+    """Signal the children of `pid` (pkill -P), then `pid` itself (kill)."""
     # TODO: there's a race condition that new descendants might
     # spawn b/w the invocations of 'pkill' and 'kill'.
     # Needs to be fixed in future.
-    try:
-        subprocess.check_call(["pkill", "-TERM", "-P", str(pid)])
-        subprocess.check_call(["kill", "-TERM", str(pid)])
-    except subprocess.CalledProcessError:
-        pass
+    retcode = subprocess.call(["pkill", signal, "-P", str(pid)])
+    # 2: Invalid options
+    # 3: Internal error (1, "no processes matched", is not treated as a failure)
+    if retcode == 2 or retcode == 3:
+        print(f"'pkill {signal} -P {pid}' failed with return code: {retcode}.")
+
+    retcode = subprocess.call(["kill", signal, str(pid)])
+    if retcode != 0:
+        print(f"'kill {signal} {pid}' failed with return code: {retcode}.")
+
+
+def kill_process_and_descendants(pid, termination_timeout):
+    """Send -TERM to `pid` and its children, sleep, then send -KILL."""
+    send_signals(pid, "-TERM")
+
+    time.sleep(termination_timeout)
+
+    send_signals(pid, "-KILL")
+
+
+def kill_processes_and_descendants(pids, termination_timeout):
+    """Send -TERM for every pid in `pids`, sleep once, then send -KILL."""
+    for pid in pids:
+        send_signals(pid, "-TERM")
 
     time.sleep(termination_timeout)
 
-    try:
-        subprocess.check_call(["pkill", "-KILL", "-P", str(pid)])
-        subprocess.check_call(["kill", "-KILL", str(pid)])
-    except subprocess.CalledProcessError:
-        pass
+    for pid in pids:
+        send_signals(pid, "-KILL")
+
+
+async def async_send_signals(pids, signal):
+    """Run `pkill -P` and then `kill` with `signal` for every pid in `pids`."""
+    pkill_processes = [
+        await asyncio.create_subprocess_exec("pkill", signal, "-P", str(pid))
+        for pid in pids
+    ]
+
+    for proc in pkill_processes:
+        await proc.wait()
+
+    kill_processes = [
+        await asyncio.create_subprocess_exec("kill", signal, str(pid)) for pid in pids
+    ]
+
+    for proc in kill_processes:
+        await proc.wait()
+
+
+async def async_kill_processes_and_descendants(pids, termination_timeout):
+    """Asyncio counterpart of `kill_processes_and_descendants`."""
+    await async_send_signals(pids, "-TERM")
+
+    await asyncio.sleep(termination_timeout)
+
+    await async_send_signals(pids, "-KILL")
 
 
 class LogReadTimeoutError(Exception):
@@ -42,6 +86,18 @@ class SubprocessManager(object):
     def __init__(self):
         self.commands: Dict[int, CommandManager] = {}
 
+        try:
+
+            async def handle_sigint():
+                await self._async_handle_sigint()
+
+            asyncio.get_running_loop().add_signal_handler(
+                signal.SIGINT, lambda: asyncio.create_task(handle_sigint())
+            )
+
+        except RuntimeError:
+            signal.signal(signal.SIGINT, self._handle_sigint)
+
     async def __aenter__(self) -> "SubprocessManager":
         return self
 
@@ -81,8 +137,12 @@ def run_command(
         """
 
         command_obj = CommandManager(command, env, cwd)
-        pid = command_obj.run(show_output=show_output)
+        pid = command_obj.run(show_output=show_output, wait=False)
+
         self.commands[pid] = command_obj
+
+        command_obj.sync_wait()
+
         return pid
 
     async def async_run_command(
@@ -138,6 +198,42 @@ def cleanup(self) -> None:
         for v in self.commands.values():
             v.cleanup()
 
+    async def kill(self, termination_timeout: float = 5):
+        """
+        Kill all managed subprocesses and their descendants.
+
+        Parameters
+        ----------
+        termination_timeout : float, default 5
+            The time to wait after sending a SIGTERM to a subprocess and its descendants
+            before sending a SIGKILL.
+        """
+
+        pids = [v.process.pid for v in self.commands.values() if v.process is not None]
+        await async_kill_processes_and_descendants(pids, termination_timeout)
+
+    def sync_kill(self, termination_timeout: float = 5):
+        """
+        Kill all managed subprocesses and their descendants synchronously.
+
+        Parameters
+        ----------
+        termination_timeout : float, default 5
+            The time to wait after sending a SIGTERM to a subprocess and its descendants
+            before sending a SIGKILL.
+        """
+        pids = [v.process.pid for v in self.commands.values() if v.process is not None]
+        kill_processes_and_descendants(
+            pids,
+            termination_timeout,
+        )
+
+    def _handle_sigint(self, signum, frame):
+        self.sync_kill()
+
+    async def _async_handle_sigint(self):
+        await self.kill()
+
 
 class CommandManager(object):
     """A manager for an individual subprocess."""
@@ -169,11 +265,11 @@ def __init__(
         self.cwd = cwd if cwd is not None else os.getcwd()
 
         self.process = None
+        self.stdout_thread = None
+        self.stderr_thread = None
         self.run_called: bool = False
         self.log_files: Dict[str, str] = {}
 
-        signal.signal(signal.SIGINT, self._handle_sigint)
-
     async def __aenter__(self) -> "CommandManager":
         return self
 
@@ -221,11 +317,23 @@ async def wait(
                     "within %s seconds." % (self.process.pid, command_string, timeout)
                 )
 
-    def run(self, show_output: bool = False):
+    def sync_wait(self):
         """
-        Run the subprocess synchronously. This can only be called once.
+        Wait for the subprocess to finish synchronously.
 
-        This also waits on the process implicitly.
+        You can only call `sync_wait` if `run` has already been called.
+        """
+
+        if not self.run_called:
+            raise RuntimeError("No command run yet to wait for...")
+
+        self.process.wait()
+        self.stdout_thread.join()
+        self.stderr_thread.join()
+
+    def run(self, show_output: bool = False, wait: bool = True) -> int:
+        """
+        Run the subprocess synchronously. This can only be called once.
 
         Parameters
         ----------
@@ -234,6 +342,10 @@ def run(self, show_output: bool = False):
             They can be accessed later by reading the files present in:
                 - self.log_files["stdout"]
                 - self.log_files["stderr"]
+        wait : bool, default True
+            Wait for the process to finish before returning.
+            If false, the process will run in the background. You can then wait on
+            the process (using `sync_wait`) or kill it (using `sync_kill`).
         """
 
         if not self.run_called:
@@ -265,22 +377,22 @@ def stream_to_stdout_and_file(pipe, log_file):
 
                 self.run_called = True
 
-                stdout_thread = threading.Thread(
+                self.stdout_thread = threading.Thread(
                     target=stream_to_stdout_and_file,
                     args=(self.process.stdout, stdout_logfile),
                 )
-                stderr_thread = threading.Thread(
+                self.stderr_thread = threading.Thread(
                     target=stream_to_stdout_and_file,
                     args=(self.process.stderr, stderr_logfile),
                 )
 
-                stdout_thread.start()
-                stderr_thread.start()
-
-                self.process.wait()
+                self.stdout_thread.start()
+                self.stderr_thread.start()
 
-                stdout_thread.join()
-                stderr_thread.join()
+                if wait:
+                    self.process.wait()
+                    self.stdout_thread.join()
+                    self.stderr_thread.join()
 
                 return self.process.pid
             except Exception as e:
@@ -457,8 +569,21 @@ async def kill(self, termination_timeout: float = 5):
         else:
             print("No process to kill.")
 
-    def _handle_sigint(self, signum, frame):
-        asyncio.create_task(self.kill())
+    def sync_kill(self, termination_timeout: float = 5):
+        """
+        Kill the subprocess and its descendants synchronously.
+
+        Parameters
+        ----------
+        termination_timeout : float, default 5
+            The time to wait after sending a SIGTERM to the process and its descendants
+            before sending a SIGKILL.
+        """
+
+        if self.process is not None:
+            kill_process_and_descendants(self.process.pid, termination_timeout)
+        else:
+            print("No process to kill.")
 
 
 async def main():