async def __async_get_executing_run(self, tfp_runner_attribute, command_obj):
    """
    Asynchronously build an `ExecutingRun` for a subprocess that was just launched.

    Restores the saved environment, awaits the content of the file named by
    `tfp_runner_attribute.name`, splits that content on its last ':' into a
    metadata string and a pathspec, and wraps the resulting `Run` together
    with `command_obj`.

    Parameters
    ----------
    tfp_runner_attribute : file-like
        Object whose `name` attribute is the path of the file to read.
    command_obj : CommandManager
        Handle on the launched subprocess; its `command` and `log_files`
        are used to build the error message on failure.

    Returns
    -------
    ExecutingRun
        Wrapper around this runner, `command_obj` and the resolved `Run`.

    Raises
    ------
    RuntimeError
        If reading the file fails with `CalledProcessError` or a timeout.
        The message contains the command line and any non-blank
        stdout/stderr log content; the original exception is chained.
    """
    # Function-scope import: only needed to name asyncio.TimeoutError below.
    import asyncio

    try:
        clear_and_set_os_environ(self.old_env)

        content = await async_read_from_file_when_ready(
            tfp_runner_attribute.name, command_obj, timeout=self.file_read_timeout
        )
        # rsplit on the LAST ':' so colons inside the metadata part survive.
        metadata_for_flow, pathspec = content.rsplit(":", maxsplit=1)
        metadata(metadata_for_flow)
        run_object = Run(pathspec, _namespace_check=False)
        return ExecutingRun(self, command_obj, run_object)
    # asyncio.TimeoutError is a distinct class from the builtin TimeoutError
    # before Python 3.11 (an alias afterwards), so both are listed to make
    # sure a timeout is always reported with the captured logs.
    except (CalledProcessError, TimeoutError, asyncio.TimeoutError) as e:
        # Context managers guarantee the log files are closed again.
        with open(command_obj.log_files["stdout"]) as stdout_file:
            stdout_log = stdout_file.read()
        with open(command_obj.log_files["stderr"]) as stderr_file:
            stderr_log = stderr_file.read()
        command = " ".join(command_obj.command)
        error_message = "Error executing: '%s':\n" % command
        if stdout_log.strip():
            error_message += "\nStdout:\n%s\n" % stdout_log
        if stderr_log.strip():
            error_message += "\nStderr:\n%s\n" % stderr_log
        raise RuntimeError(error_message) from e
async def async_read_from_file_when_ready(
    file_path: str, command_obj: "CommandManager", timeout: float = 5
):
    """
    Asynchronously wait until `file_path` has content and return it.

    The file is re-read every 0.1 seconds, yielding to the event loop in
    between, so the coroutine returns as soon as content is available
    instead of waiting for the subprocess to exit.

    Parameters
    ----------
    file_path : str
        Path of the file to read (decoded as UTF-8).
    command_obj : CommandManager
        Handle on the subprocess expected to write the file. Its
        `process.returncode` is checked to detect that the process has
        exited; `command` is used for error reporting.
    timeout : float, default 5
        Maximum number of seconds to wait for content.

    Returns
    -------
    str
        The non-empty content read from the file.

    Raises
    ------
    CalledProcessError
        If the process has exited and the file is still empty.
    TimeoutError
        If no content appeared within `timeout` seconds. The builtin
        exception is raised so callers catching `TimeoutError` behave the
        same on every Python version.
    """
    import asyncio
    import time

    deadline = time.monotonic() + timeout

    with open(file_path, "r", encoding="utf-8") as file_pointer:
        content = file_pointer.read()
        while not content:
            if command_obj.process.returncode is not None:
                # The process is gone. Read once more so content written
                # between the previous read and the exit is not missed.
                content = file_pointer.read()
                if content:
                    break
                raise CalledProcessError(
                    command_obj.process.returncode, command_obj.command
                )
            if time.monotonic() > deadline:
                raise TimeoutError(
                    "Timeout while waiting for file content from '%s'" % file_path
                )
            # Non-blocking pause: lets other tasks run while we wait.
            await asyncio.sleep(0.1)
            content = file_pointer.read()
        return content