[parallel-fixes] core + test changes
- fix tests for the interplay of @secrets and @parallel
- local runtime allows secrets only on worker tasks
- secrets are set on all kinds of tasks (control/worker) when run remotely
- fix tests for ubf based on the new changes to core.
- fix tests for tag-catch for ubf based on the new changes to core.
- internal ubf decorator now has an internal_task_type set on it.
- [feedback] register metadata in the parallel decorator
- [feedback] @parallel injects `parallel` into `current`:
    - move `current.parallel` from `metaflow_current` to `parallel_decorator`
- [feedback] appropriately set task metadata for parallel tasks
    - The 'world size' metadata is set in the @parallel decorator.
    - The 'node-index' metadata, however, varies depending on the type of compute environment executing the task, so it is set in the appropriate compute decorators.
    - One reason to set 'node-index' within compute decorators is that the parallel implementation in the compute decorator might not directly set the required environment variables (`MF_PARALLEL_*`). Instead, these values may be established during the `task_pre_step` phase of the compute decorator from other environment variables set by the backend.
    - add some AWS Batch changes
- [feedback] safety check for `_parallel_ubf_iter` in task_pre_step for @parallel
- [feedback] set `is_parallel` in current to denote that a step is running under an `@parallel` decorator (a usage sketch follows this list).
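
A minimal sketch of how these changes surface in user code; the flow below is hypothetical and assumes the new `current.parallel` / `current.is_parallel` behavior described above:

    from metaflow import FlowSpec, step, parallel, current

    class ParallelDemoFlow(FlowSpec):
        @step
        def start(self):
            # Fan out into 4 @parallel tasks (1 control + 3 workers).
            self.next(self.train, num_parallel=4)

        @parallel
        @step
        def train(self):
            # `current.parallel` is now injected by @parallel itself rather
            # than hard-coded on metaflow_current; values are read lazily
            # from the MF_PARALLEL_* environment variables.
            print(
                current.parallel.main_ip,
                current.parallel.num_nodes,
                current.parallel.node_index,
            )
            assert current.is_parallel
            self.next(self.join)

        @step
        def join(self, inputs):
            self.next(self.end)

        @step
        def end(self):
            pass

    if __name__ == "__main__":
        ParallelDemoFlow()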
valayDave committed Jul 4, 2024
1 parent 2232d0a commit 653f1e1
Showing 8 changed files with 193 additions and 82 deletions.
8 changes: 0 additions & 8 deletions metaflow/metaflow_current.py
@@ -260,14 +260,6 @@ def username(self) -> Optional[str]:
"""
return self._username

@property
def parallel(self):
return Parallel(
main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"),
num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")),
node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")),
)

@property
def tags(self):
"""
29 changes: 16 additions & 13 deletions metaflow/plugins/aws/batch/batch_decorator.py
@@ -261,8 +261,8 @@ def task_pre_step(
# metadata. A rudimentary way to detect non-local execution is to
# check for the existence of AWS_BATCH_JOB_ID environment variable.

meta = {}
if "AWS_BATCH_JOB_ID" in os.environ:
meta = {}
meta["aws-batch-job-id"] = os.environ["AWS_BATCH_JOB_ID"]
meta["aws-batch-job-attempt"] = os.environ["AWS_BATCH_JOB_ATTEMPT"]
meta["aws-batch-ce-name"] = os.environ["AWS_BATCH_CE_NAME"]
@@ -290,18 +290,6 @@ def task_pre_step(
instance_meta = get_ec2_instance_metadata()
meta.update(instance_meta)

entries = [
MetaDatum(
field=k,
value=v,
type=k,
tags=["attempt_id:{0}".format(retry_count)],
)
for k, v in meta.items()
]
# Register book-keeping metadata for debugging.
metadata.register_metadata(run_id, step_name, task_id, entries)

self._save_logs_sidecar = Sidecar("save_logs_periodically")
self._save_logs_sidecar.start()

@@ -322,6 +310,21 @@ def task_pre_step(

if num_parallel >= 1:
_setup_multinode_environment()
# current.parallel.node_index is correctly available here.
meta.update({"parallel-node-index": current.parallel.node_index})

if len(meta) > 0:
entries = [
MetaDatum(
field=k,
value=v,
type=k,
tags=["attempt_id:{0}".format(retry_count)],
)
for k, v in meta.items()
]
# Register book-keeping metadata for debugging.
metadata.register_metadata(run_id, step_name, task_id, entries)

def task_finished(
self, step_name, flow, graph, is_task_ok, retry_count, max_retries
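To illustrate why the 'node-index' metadatum is registered by the compute decorator: the `MF_PARALLEL_*` variables that `current.parallel` reads are only established once the backend-specific setup has run inside `task_pre_step`. A rough sketch of what `_setup_multinode_environment` could look like for AWS Batch multi-node parallel jobs (the exact variable mapping is an assumption, not the actual implementation):

    import os

    def _setup_multinode_environment():
        # AWS Batch multi-node parallel jobs expose their own variables;
        # @parallel only understands the generic MF_PARALLEL_* ones.
        os.environ["MF_PARALLEL_MAIN_IP"] = os.environ.get(
            # Unset on the main node itself; simplified fallback here.
            "AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS",
            "127.0.0.1",
        )
        os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["AWS_BATCH_JOB_NUM_NODES"]
        os.environ["MF_PARALLEL_NODE_INDEX"] = os.environ["AWS_BATCH_JOB_NODE_INDEX"]

This ordering is why `current.parallel.node_index` is only reliable after `_setup_multinode_environment()` has run, and why the `parallel-node-index` entry above is registered here rather than in @parallel's own `task_pre_step`.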
109 changes: 101 additions & 8 deletions metaflow/plugins/parallel_decorator.py
@@ -1,11 +1,30 @@
from collections import namedtuple
from metaflow.decorators import StepDecorator
from metaflow.unbounded_foreach import UBF_CONTROL
from metaflow.unbounded_foreach import UBF_CONTROL, CONTROL_TASK_TAG
from metaflow.exception import MetaflowException
from metaflow.metadata import MetaDatum
from metaflow.metaflow_current import current, Parallel
import os
import sys


class ParallelDecorator(StepDecorator):
"""
MF Add To Current
-----------------
parallel -> metaflow.metaflow_current.Parallel
@@ Returns
-------
`Parallel`: `namedtuple` with the following fields:
- main_ip : str
The IP address of the control task.
- num_nodes : int
The total number of tasks created by @parallel
- node_index : int
The index of the current task in all the @parallel tasks.
"""

name = "parallel"
defaults = {}
IS_PARALLEL = True
@@ -16,7 +35,6 @@ def __init__(self, attributes=None, statically_defined=False):
def runtime_step_cli(
self, cli_args, retry_count, max_user_code_retries, ubf_context
):

if ubf_context == UBF_CONTROL:
num_parallel = cli_args.task.ubf_iter.num_parallel
cli_args.command_options["num-parallel"] = str(num_parallel)
@@ -25,6 +43,79 @@ def step_init(
self, flow, graph, step_name, decorators, environment, flow_datastore, logger
):
self.environment = environment
# We choose `setattr` instead of the default `current._update_env` for two reasons: first, to preserve
# the datatype of `current.parallel` as a namedtuple, and second, to ensure the contents of the namedtuple
# are accurately set based on when they are accessed in the Metaflow lifecycle. The values for `main_ip`,
# `num_nodes`, and `node_index` are determined by environment variables which may be set in `task_pre_step`.
# Thus, we use a `property` to lazily evaluate these namedtuple values.
setattr(
current.__class__,
"parallel",
property(
fget=lambda _: Parallel(
main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"),
num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")),
node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")),
)
),
)

def task_pre_step(
self,
step_name,
task_datastore,
metadata,
run_id,
task_id,
flow,
graph,
retry_count,
max_user_code_retries,
ubf_context,
inputs,
):
from metaflow import current

# Set `is_parallel` to `True` in `current`, just like we do
# with `is_production` in the project decorator.
current._update_env(
{
"is_parallel": True,
}
)

self.input_paths = [obj.pathspec for obj in inputs]
if not hasattr(flow, "_parallel_ubf_iter"):
raise MetaflowException(
"Parallel decorator is only supported in unbounded foreach steps."
)
task_metadata_list = [
MetaDatum(
field="parallel-world-size",
value=flow._parallel_ubf_iter.num_parallel,
type="parallel-world-size",
tags=["attempt_id:{0}".format(0)],
)
]
if ubf_context == UBF_CONTROL:
# A task's tags are now inherited from its ancestral run, so we can
# no longer rely on a task's tags to indicate that it is a control
# task. Instead, we register a task metadata entry marking it as a
# "control task".
#
# Within the metaflow repo, the only consumer of this "control
# task" indicator is the integration test suite (see
# Step.control_tasks() in the client API).
task_metadata_list += [
MetaDatum(
field="internal_task_type",
value=CONTROL_TASK_TAG,
type="internal_task_type",
tags=["attempt_id:{0}".format(0)],
)
]
metadata.register_metadata(run_id, step_name, task_id, task_metadata_list)

def task_decorate(
self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
@@ -47,6 +138,7 @@ def _step_func_with_setup():
env_to_use,
_step_func_with_setup,
retry_count,
",".join(self.input_paths),
)
else:
return _step_func_with_setup
@@ -56,7 +148,9 @@ def setup_distributed_env(self, flow):
pass


def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_count):
def _local_multinode_control_task_step_func(
flow, env_to_use, step_func, retry_count, input_paths
):
"""
Used as multinode UBF control task when run in local mode.
"""
@@ -80,10 +174,7 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c
run_id = current.run_id
step_name = current.step_name
control_task_id = current.task_id

(_, split_step_name, split_task_id) = control_task_id.split("-")[1:]
# UBF handling for multinode case
top_task_id = control_task_id.replace("control-", "") # chop "-0"
mapper_task_ids = [control_task_id]
# If we are running inside Conda, we use the base executable FIRST;
# the conda environment will then be used when runtime_step_cli is
@@ -93,12 +184,13 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c
script = sys.argv[0]

# start workers
# TODO: Logs for worker processes are currently assigned to the control
# process; this should be fixed at some point
subprocesses = []
for node_index in range(1, num_parallel):
task_id = "%s_node_%d" % (top_task_id, node_index)
task_id = "%s_node_%d" % (control_task_id, node_index)
mapper_task_ids.append(task_id)
os.environ["MF_PARALLEL_NODE_INDEX"] = str(node_index)
input_paths = "%s/%s/%s" % (run_id, split_step_name, split_task_id)
# Override specific `step` kwargs.
kwargs = cli_args.step_kwargs
kwargs["split_index"] = str(node_index)
@@ -109,6 +201,7 @@ def _local_multinode_control_task_step_func(flow, env_to_use, step_func, retry_c
kwargs["retry_count"] = str(retry_count)

cmd = cli_args.step_command(executable, script, step_name, step_kwargs=kwargs)

p = subprocess.Popen(cmd)
subprocesses.append(p)

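A standalone sketch of the lazy-property trick `step_init` uses above, with hypothetical names; the point is that `setattr` with a `property` on the class defers reading the environment until the attribute is accessed:

    import os

    class Current:
        pass

    current = Current()

    # A plain attribute would snapshot the environment at definition time,
    # before task_pre_step has had a chance to set MF_PARALLEL_NODE_INDEX.
    # A property on the class re-reads the environment on every access:
    setattr(
        Current,
        "node_index",
        property(fget=lambda _: int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0"))),
    )

    assert current.node_index == 0
    os.environ["MF_PARALLEL_NODE_INDEX"] = "3"  # e.g. set during task_pre_step
    assert current.node_index == 3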
13 changes: 11 additions & 2 deletions metaflow/plugins/secrets/secrets_decorator.py
@@ -210,8 +210,17 @@ def task_pre_step(
ubf_context,
inputs,
):
if ubf_context == UBF_CONTROL:
"""control tasks (as used in "unbounded for each") don't need secrets"""
# We skip secret injection for "locally" launched UBF_CONTROL tasks because otherwise there
# will be an env var clash with UBF_WORKER tasks, which would fail the flow in the @secrets decorator's `task_pre_step`.
if (
ubf_context == UBF_CONTROL
and os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local"
):
# When we run @parallel tasks "locally", the control task creates the worker tasks, and the
# worker tasks inherit the control task's environment variables. If we didn't skip setting
# secrets in the control task, the worker tasks would already have the secrets set as
# environment variables, causing @secrets' `task_pre_step` to fail. In remote settings
# (e.g. AWS Batch / Kubernetes), the worker and control tasks are created independently,
# so there is no chance of an env var clash.
return
# List of pairs (secret_spec, env_vars_from_this_spec)
all_secrets_env_vars = []
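A usage-level sketch of the rule above, with a hypothetical secret source and key; on the local runtime only the worker tasks receive the injected env vars, while on remote backends both control and worker tasks do:

    from metaflow import FlowSpec, step, parallel, secrets

    class SecretsParallelFlow(FlowSpec):
        @step
        def start(self):
            self.next(self.train, num_parallel=2)

        @secrets(sources=["my-secret-source"])  # hypothetical source
        @parallel
        @step
        def train(self):
            import os

            # Locally, the control task skips injection so that the worker
            # processes it forks do not inherit these variables and then
            # clash when @secrets tries to set them again.
            print(os.environ.get("MY_SECRET_KEY"))  # hypothetical key
            self.next(self.join)

        @step
        def join(self, inputs):
            self.next(self.end)

        @step
        def end(self):
            pass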
42 changes: 39 additions & 3 deletions metaflow/plugins/test_unbounded_foreach_decorator.py
@@ -8,8 +8,14 @@
from metaflow.cli_args import cli_args
from metaflow.decorators import StepDecorator
from metaflow.exception import MetaflowException
from metaflow.unbounded_foreach import UnboundedForeachInput, UBF_CONTROL, UBF_TASK
from metaflow.unbounded_foreach import (
UnboundedForeachInput,
UBF_CONTROL,
UBF_TASK,
CONTROL_TASK_TAG,
)
from metaflow.util import to_unicode
from metaflow.metadata import MetaDatum


class InternalTestUnboundedForeachInput(UnboundedForeachInput):
@@ -60,13 +66,43 @@ def step_init(
):
self.environment = environment

def task_pre_step(
self,
step_name,
task_datastore,
metadata,
run_id,
task_id,
flow,
graph,
retry_count,
max_user_code_retries,
ubf_context,
inputs,
):
if ubf_context == UBF_CONTROL:
metadata.register_metadata(
run_id,
step_name,
task_id,
[
MetaDatum(
field="internal_task_type",
value=CONTROL_TASK_TAG,
type="internal_task_type",
tags=["attempt_id:{0}".format(0)],
)
],
)
self.input_paths = [obj.pathspec for obj in inputs]

def control_task_step_func(self, flow, graph, retry_count):
from metaflow import current

run_id = current.run_id
step_name = current.step_name
control_task_id = current.task_id
(_, split_step_name, split_task_id) = control_task_id.split("-")[1:]
# (_, split_step_name, split_task_id) = control_task_id.split("-")[1:]
# If we are running inside Conda, we use the base executable FIRST;
# the conda environment will then be used when runtime_step_cli is
# called. This is so that it can properly set up all the metaflow
@@ -97,7 +133,7 @@ def control_task_step_func(self, flow, graph, retry_count):
task_id = "%s-%d" % (control_task_id.replace("control-", "test-ubf-"), i)
pathspec = "%s/%s/%s" % (run_id, step_name, task_id)
mapper_tasks.append(to_unicode(pathspec))
input_paths = "%s/%s/%s" % (run_id, split_step_name, split_task_id)
input_paths = ",".join(self.input_paths)

# Override specific `step` kwargs.
kwargs = cli_args.step_kwargs
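For reference, the client-API consumer mentioned in the comments above: the `internal_task_type` metadatum is what lets the client identify control tasks. A hedged usage sketch (flow, run, and step names are made up):

    from metaflow import Step

    parallel_step = Step("ParallelDemoFlow/1234/train")
    for task in parallel_step.control_tasks():
        print(task.pathspec)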
[Diffs for the remaining 3 changed files did not load.]
