From 653f1e1bb0cdc65f84f0bfd4e3a2ef62fbb9542b Mon Sep 17 00:00:00 2001 From: Valay Dave Date: Wed, 3 Jul 2024 18:44:41 -0700 Subject: [PATCH] [parallel-fixes] core + test changes - fix tests for interplay of @secrets and @parallel - Local runtime will allow secrets only on worker tasks - Secrets will be set on all kinds of tasks when run remotely (control/worker) - fix tests for ubf based on new changes to core. - fix tests for tag-catch for ubf based on new changes to core. - internal ubf decorator has an internal_task_type set to it. - [feedback] register metadata in parallel decorator - [feedback]@parallel inject in current: - move `current.parallel` from `metaflow_current` to `parallel_decorator` - [feedback] appropriately setting task-metadata for parallel stuff - The 'world size' metadata will be set in the @parallel decorator. - The 'node-index' metadata, however, varies depending on the type of computing environment executing the task so it will be set in the appropriate compute decorators. - One reason to specify 'node-index' within compute decorators is that the parallel implementation in the compute decorator might not directly set the required environment variables (`MF_PARALLEL_*`). Instead, these values may be established during the `task_pre_step` phase of the compute decorator using other environment variables set during the implementation. - adding some aws batch changes - [feedback] safety check for _parallel_ubf_iter in task_pre_step for @parallel - [feedback] set `is_parallel` in current to denote a step is running under an `@parallel` decorator. 
--- metaflow/metaflow_current.py | 8 -- metaflow/plugins/aws/batch/batch_decorator.py | 29 ++--- metaflow/plugins/parallel_decorator.py | 109 ++++++++++++++++-- metaflow/plugins/secrets/secrets_decorator.py | 13 ++- .../test_unbounded_foreach_decorator.py | 42 ++++++- metaflow/runtime.py | 60 +++------- test/core/tests/secrets_decorator.py | 10 +- test/core/tests/tag_catch.py | 4 +- 8 files changed, 193 insertions(+), 82 deletions(-) diff --git a/metaflow/metaflow_current.py b/metaflow/metaflow_current.py index 6b89c9a1a2..ecd9730ebb 100644 --- a/metaflow/metaflow_current.py +++ b/metaflow/metaflow_current.py @@ -260,14 +260,6 @@ def username(self) -> Optional[str]: """ return self._username - @property - def parallel(self): - return Parallel( - main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"), - num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")), - node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")), - ) - @property def tags(self): """ diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 202d1b33bc..599181ba2a 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -261,8 +261,8 @@ def task_pre_step( # metadata. A rudimentary way to detect non-local execution is to # check for the existence of AWS_BATCH_JOB_ID environment variable. + meta = {} if "AWS_BATCH_JOB_ID" in os.environ: - meta = {} meta["aws-batch-job-id"] = os.environ["AWS_BATCH_JOB_ID"] meta["aws-batch-job-attempt"] = os.environ["AWS_BATCH_JOB_ATTEMPT"] meta["aws-batch-ce-name"] = os.environ["AWS_BATCH_CE_NAME"] @@ -290,18 +290,6 @@ def task_pre_step( instance_meta = get_ec2_instance_metadata() meta.update(instance_meta) - entries = [ - MetaDatum( - field=k, - value=v, - type=k, - tags=["attempt_id:{0}".format(retry_count)], - ) - for k, v in meta.items() - ] - # Register book-keeping metadata for debugging. 
- metadata.register_metadata(run_id, step_name, task_id, entries) - self._save_logs_sidecar = Sidecar("save_logs_periodically") self._save_logs_sidecar.start() @@ -322,6 +310,21 @@ def task_pre_step( if num_parallel >= 1: _setup_multinode_environment() + # current.parallel.node_index will be correctly available over here. + meta.update({"parallel-node-index": current.parallel.node_index}) + + if len(meta) > 0: + entries = [ + MetaDatum( + field=k, + value=v, + type=k, + tags=["attempt_id:{0}".format(retry_count)], + ) + for k, v in meta.items() + ] + # Register book-keeping metadata for debugging. + metadata.register_metadata(run_id, step_name, task_id, entries) def task_finished( self, step_name, flow, graph, is_task_ok, retry_count, max_retries diff --git a/metaflow/plugins/parallel_decorator.py b/metaflow/plugins/parallel_decorator.py index c93549926e..0a25ec6f2f 100644 --- a/metaflow/plugins/parallel_decorator.py +++ b/metaflow/plugins/parallel_decorator.py @@ -1,11 +1,30 @@ +from collections import namedtuple from metaflow.decorators import StepDecorator -from metaflow.unbounded_foreach import UBF_CONTROL +from metaflow.unbounded_foreach import UBF_CONTROL, CONTROL_TASK_TAG from metaflow.exception import MetaflowException +from metaflow.metadata import MetaDatum +from metaflow.metaflow_current import current, Parallel import os import sys class ParallelDecorator(StepDecorator): + """ + MF Add To Current + ----------------- + parallel -> metaflow.metaflow_current.Parallel + + @@ Returns + ------- + `Parallel`: `namedtuple` with the following fields: + - main_ip : str + The IP address of the control task. + - num_nodes : int + The total number of tasks created by @parallel + - node_index : int + The index of the current task in all the @parallel tasks. 
+ """ + name = "parallel" defaults = {} IS_PARALLEL = True @@ -16,7 +35,6 @@ def __init__(self, attributes=None, statically_defined=False): def runtime_step_cli( self, cli_args, retry_count, max_user_code_retries, ubf_context ): - if ubf_context == UBF_CONTROL: num_parallel = cli_args.task.ubf_iter.num_parallel cli_args.command_options["num-parallel"] = str(num_parallel) @@ -25,6 +43,79 @@ def step_init( self, flow, graph, step_name, decorators, environment, flow_datastore, logger ): self.environment = environment + # We choose `setattr` instead of the default `current._update_env` for two reasons: first, to preserve + # the datatype of `current.parallel` as a namedtuple, and second, to ensure the contents of the namedtuple + # are accurately set based on when they are accessed in the Metaflow lifecycle. The values for `main_ip`, + # `num_nodes`, and `node_index` are determined by environment variables which maybe set in `task_pre_step`. + # Thus, we use a `property` to lazily evaluate these namedtuple values. + setattr( + current.__class__, + "parallel", + property( + fget=lambda _: Parallel( + main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"), + num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")), + node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")), + ) + ), + ) + + def task_pre_step( + self, + step_name, + task_datastore, + metadata, + run_id, + task_id, + flow, + graph, + retry_count, + max_user_code_retries, + ubf_context, + inputs, + ): + from metaflow import current + + # Set `is_parallel` to `True` in `current` just like we + # with `is_production` in the project decorator. + current._update_env( + { + "is_parallel": True, + } + ) + + self.input_paths = [obj.pathspec for obj in inputs] + if not hasattr(flow, "_parallel_ubf_iter"): + raise MetaflowException( + "Parallel decorator is only supported in unbounded foreach steps." 
+ ) + task_metadata_list = [ + MetaDatum( + field="parallel-world-size", + value=flow._parallel_ubf_iter.num_parallel, + type="parallel-world-size", + tags=["attempt_id:{0}".format(0)], + ) + ] + if ubf_context == UBF_CONTROL: + # A Task's tags are now those of its ancestral Run, so we are not able + # to rely on a task's tags to indicate the presence of a control task + # so, on top of adding the tags above, we also add a task metadata + # entry indicating that this is a "control task". + # + # Here we will also add a task metadata entry to indicate "control + # task". Within the metaflow repo, the only dependency of such a + # "control task" indicator is in the integration test suite (see + # Step.control_tasks() in client API). + task_metadata_list += [ + MetaDatum( + field="internal_task_type", + value=CONTROL_TASK_TAG, + type="internal_task_type", + tags=["attempt_id:{0}".format(0)], + ) + ] + metadata.register_metadata(run_id, step_name, task_id, task_metadata_list) def task_decorate( self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context @@ -47,6 +138,7 @@ def _step_func_with_setup(): env_to_use, _step_func_with_setup, retry_count, + ",".join(self.input_paths), ) else: return _step_func_with_setup @@ -56,7 +148,9 @@ def setup_distributed_env(self, flow): pass -def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_count): +def _local_multinode_control_task_step_func( + flow, env_to_use, step_func, retry_count, input_paths +): """ Used as multinode UBF control task when run in local mode. 
""" @@ -80,10 +174,7 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c run_id = current.run_id step_name = current.step_name control_task_id = current.task_id - - (_, split_step_name, split_task_id) = control_task_id.split("-")[1:] # UBF handling for multinode case - top_task_id = control_task_id.replace("control-", "") # chop "-0" mapper_task_ids = [control_task_id] # If we are running inside Conda, we use the base executable FIRST; # the conda environment will then be used when runtime_step_cli is @@ -93,12 +184,13 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c script = sys.argv[0] # start workers + # TODO: Logs for worker processes are assigned to control process as of today, which + # should be fixed at some point subprocesses = [] for node_index in range(1, num_parallel): - task_id = "%s_node_%d" % (top_task_id, node_index) + task_id = "%s_node_%d" % (control_task_id, node_index) mapper_task_ids.append(task_id) os.environ["MF_PARALLEL_NODE_INDEX"] = str(node_index) - input_paths = "%s/%s/%s" % (run_id, split_step_name, split_task_id) # Override specific `step` kwargs. 
kwargs = cli_args.step_kwargs kwargs["split_index"] = str(node_index) @@ -109,6 +201,7 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c kwargs["retry_count"] = str(retry_count) cmd = cli_args.step_command(executable, script, step_name, step_kwargs=kwargs) + p = subprocess.Popen(cmd) subprocesses.append(p) diff --git a/metaflow/plugins/secrets/secrets_decorator.py b/metaflow/plugins/secrets/secrets_decorator.py index 641711490f..a0638fe103 100644 --- a/metaflow/plugins/secrets/secrets_decorator.py +++ b/metaflow/plugins/secrets/secrets_decorator.py @@ -210,8 +210,17 @@ def task_pre_step( ubf_context, inputs, ): - if ubf_context == UBF_CONTROL: - """control tasks (as used in "unbounded for each") don't need secrets""" + # We will skip the secret injection for "locally" launched UBF_CONTROL tasks because otherwise there + # will be a env var clash with UBF_WORKER tasks which will fail the flow in the @secrets decorator's `task_pre_step`. + if ( + ubf_context == UBF_CONTROL + and os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local" + ): + # When we "locally" run @parallel tasks, the control task will create the worker tasks and the environment variables + # of the control task are inherited by the worker tasks. If we don't skip setting secrets in the control task then the + # worker tasks will already have the secrets set as environment variables, causing the @secrets' `task_pre_step` to fail. + # In remote settings, (e.g. AWS Batch/Kubernetes), the worker task and control task are independently created + # so there is no chances of an env var clash. 
return # List of pairs (secret_spec, env_vars_from_this_spec) all_secrets_env_vars = [] diff --git a/metaflow/plugins/test_unbounded_foreach_decorator.py b/metaflow/plugins/test_unbounded_foreach_decorator.py index e5f2962fa3..116c30b149 100644 --- a/metaflow/plugins/test_unbounded_foreach_decorator.py +++ b/metaflow/plugins/test_unbounded_foreach_decorator.py @@ -8,8 +8,14 @@ from metaflow.cli_args import cli_args from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException -from metaflow.unbounded_foreach import UnboundedForeachInput, UBF_CONTROL, UBF_TASK +from metaflow.unbounded_foreach import ( + UnboundedForeachInput, + UBF_CONTROL, + UBF_TASK, + CONTROL_TASK_TAG, +) from metaflow.util import to_unicode +from metaflow.metadata import MetaDatum class InternalTestUnboundedForeachInput(UnboundedForeachInput): @@ -60,13 +66,43 @@ def step_init( ): self.environment = environment + def task_pre_step( + self, + step_name, + task_datastore, + metadata, + run_id, + task_id, + flow, + graph, + retry_count, + max_user_code_retries, + ubf_context, + inputs, + ): + if ubf_context == UBF_CONTROL: + metadata.register_metadata( + run_id, + step_name, + task_id, + [ + MetaDatum( + field="internal_task_type", + value=CONTROL_TASK_TAG, + type="internal_task_type", + tags=["attempt_id:{0}".format(0)], + ) + ], + ) + self.input_paths = [obj.pathspec for obj in inputs] + def control_task_step_func(self, flow, graph, retry_count): from metaflow import current run_id = current.run_id step_name = current.step_name control_task_id = current.task_id - (_, split_step_name, split_task_id) = control_task_id.split("-")[1:] + # (_, split_step_name, split_task_id) = control_task_id.split("-")[1:] # If we are running inside Conda, we use the base executable FIRST; # the conda environment will then be used when runtime_step_cli is # called. 
This is so that it can properly set up all the metaflow @@ -97,7 +133,7 @@ def control_task_step_func(self, flow, graph, retry_count): task_id = "%s-%d" % (control_task_id.replace("control-", "test-ubf-"), i) pathspec = "%s/%s/%s" % (run_id, step_name, task_id) mapper_tasks.append(to_unicode(pathspec)) - input_paths = "%s/%s/%s" % (run_id, split_step_name, split_task_id) + input_paths = ",".join(self.input_paths) # Override specific `step` kwargs. kwargs = cli_args.step_kwargs diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 51880cb1f4..e5fb2e08a8 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -1108,56 +1108,24 @@ def mark_resume_done(self): def _get_task_id(self, task_id): already_existed = True + tags = [] if self.ubf_context == UBF_CONTROL: - [input_path] = self.input_paths - run, input_step, input_task = input_path.split("/") - # We associate the control task-id to be 1:1 with the split node - # where the unbounded-foreach was defined. - # We prefer encoding the corresponding split into the task_id of - # the control node; so it has access to this information quite - # easily. There is anyway a corresponding int id stored in the - # metadata backend - so this should be fine. - task_id = "control-%s-%s-%s" % (run, input_step, input_task) - # Register only regular Metaflow (non control) tasks. + tags = [CONTROL_TASK_TAG] + # Register Metaflow tasks. if task_id is None: - task_id = str(self.metadata.new_task_id(self.run_id, self.step)) + task_id = str( + self.metadata.new_task_id(self.run_id, self.step, sys_tags=tags) + ) already_existed = False else: - # task_id is preset only by persist_constants() or control tasks. 
- if self.ubf_context == UBF_CONTROL: - tags = [CONTROL_TASK_TAG] - attempt_id = 0 - already_existed = not self.metadata.register_task_id( - self.run_id, - self.step, - task_id, - attempt_id, - sys_tags=tags, - ) - # A Task's tags are now those of its ancestral Run, so we are not able - # to rely on a task's tags to indicate the presence of a control task - # so, on top of adding the tags above, we also add a task metadata - # entry indicating that this is a "control task". - # - # Here we will also add a task metadata entry to indicate "control task". - # Within the metaflow repo, the only dependency of such a "control task" - # indicator is in the integration test suite (see Step.control_tasks() in - # client API). - task_metadata_list = [ - MetaDatum( - field="internal_task_type", - value=CONTROL_TASK_TAG, - type="internal_task_type", - tags=["attempt_id:{0}".format(attempt_id)], - ) - ] - self.metadata.register_metadata( - self.run_id, self.step, task_id, task_metadata_list - ) - else: - already_existed = not self.metadata.register_task_id( - self.run_id, self.step, task_id, 0 - ) + # task_id is preset only by persist_constants(). + already_existed = not self.metadata.register_task_id( + self.run_id, + self.step, + task_id, + 0, + sys_tags=tags, + ) self.task_id = task_id self._path = "%s/%s/%s" % (self.run_id, self.step, self.task_id) diff --git a/test/core/tests/secrets_decorator.py b/test/core/tests/secrets_decorator.py index 0f0644ff93..5b41f76eb7 100644 --- a/test/core/tests/secrets_decorator.py +++ b/test/core/tests/secrets_decorator.py @@ -26,7 +26,15 @@ def step_all(self): import os from metaflow import current - if current.task_id.startswith("control-"): + if ( + self._graph[current.step_name].parallel_step + and current.parallel.node_index == 0 + and os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local" + ): + # We don't check control task secrets when there is a parallel step + # run locally. 
+ # todo (future): support the case where secrets need to be passsed to the + # control task in a parallel step when run locally. return assert os.environ.get("secret_1") == "Pizza is a vegetable." diff --git a/test/core/tests/tag_catch.py b/test/core/tests/tag_catch.py index a326a46684..a8efb91956 100644 --- a/test/core/tests/tag_catch.py +++ b/test/core/tests/tag_catch.py @@ -121,7 +121,9 @@ def check_results(self, flow, checker): data = task.data got = sorted(m.value for m in task.metadata if m.type == "attempt") if flow._graph[step.id].parallel_step: - if "control" in task.id: + if task.metadata_dict.get( + "internal_task_type", None + ): # Only control tasks have internal_task_type set assert_equals(list(map(str, range(attempts))), got) else: # non-control tasks have one attempt less for parallel steps