
Commit

Merge pull request #1 from emattia/rileyhun/hunr/hpc+ray
Add error handling, node watcher, and auto checkpoint_path var
rileyhun authored Aug 23, 2023
2 parents f94763d + c7d94a8 commit 88bb9bb
Showing 1 changed file with 150 additions and 69 deletions.
219 changes: 150 additions & 69 deletions metaflow/plugins/frameworks/ray.py
@@ -1,101 +1,124 @@
import subprocess
import os
import sys
import time
import json
import signal
from pathlib import Path
from threading import Thread
from metaflow.exception import MetaflowException
from metaflow.unbounded_foreach import UBF_CONTROL
from metaflow.plugins.parallel_decorator import ParallelDecorator, _local_multinode_control_task_step_func

RAY_CHECKPOINT_VAR_NAME = 'checkpoint_path'
RAY_JOB_COMPLETE_VAR = 'ray_job_completed'
RAY_NODE_STARTED_VAR = 'node_started'
CONTROL_TASK_S3_ID = 'control'

class RayParallelDecorator(ParallelDecorator):

name = "ray_parallel"
defaults = {"main_port": None}
defaults = {"main_port": None, "worker_polling_freq": 10, "all_nodes_started_timeout": 90}
IS_PARALLEL = True

    def task_decorate(
        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
    ):

        from functools import partial
        from metaflow import S3, current
        from metaflow.metaflow_config import DATATOOLS_S3ROOT

        def _empty_worker_task():
            pass  # local case

        def _worker_heartbeat(polling_freq=self.attributes["worker_polling_freq"], var=RAY_JOB_COMPLETE_VAR):
            while not json.loads(s3.get(CONTROL_TASK_S3_ID).blob)[var]:
                time.sleep(polling_freq)

        def _control_wrapper(step_func, flow, var=RAY_JOB_COMPLETE_VAR):
            watcher = NodeParticipationWatcher(expected_num_nodes=current.num_nodes, polling_freq=10)
            try:
                step_func()
            except Exception as e:
                raise ControlTaskException(e)
            finally:
                watcher.end()
                s3.put(CONTROL_TASK_S3_ID, json.dumps({var: True}))

        s3 = S3(run=flow)
        ensure_ray_installed()
        if os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local":
            checkpoint_path = os.path.join(os.getcwd(), "ray_checkpoints")
        else:
            checkpoint_path = os.path.join(
                DATATOOLS_S3ROOT, current.flow_name, current.run_id, "ray_checkpoints"
            )
        setattr(flow, RAY_CHECKPOINT_VAR_NAME, checkpoint_path)

        if os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local":
            if ubf_context == UBF_CONTROL:
                env_to_use = getattr(self.environment, "base_env", self.environment)
                return partial(
                    _local_multinode_control_task_step_func,
                    # assigns the flow._control_mapper_tasks variables & runs worker subprocesses.
                    flow,
                    env_to_use,
                    step_func,
                    # run user code and let ray.init() auto-detect available resources. could swap this for an entrypoint.py file to match the ray job submission API.
                    retry_count,
                )
            return partial(_empty_worker_task)
        else:
            self.setup_distributed_env(flow, ubf_context)
            if ubf_context == UBF_CONTROL:
                return partial(_control_wrapper, step_func=step_func, flow=flow)
            return partial(_worker_heartbeat)

    def setup_distributed_env(self, flow, ubf_context):
        py_cli_path = Path(sys.executable)
        if py_cli_path.is_symlink():
            py_cli_path = os.readlink(py_cli_path)
        # Cast to str so .split works whether or not sys.executable is a symlink.
        ray_cli_path = os.path.join(str(py_cli_path).split('python')[0], 'ray')
        setup_ray_distributed(self.attributes["main_port"], self.attributes["all_nodes_started_timeout"], ray_cli_path, flow, ubf_context)

    def ensure_ray_air_installed(self):
        try:
            import ray
        except ImportError:
            print("Ray is not installed. Installing latest version of ray-air package.")
            subprocess.run([sys.executable, "-m", "pip", "install", "-U", "ray[air]"])

def setup_ray_distributed(
    main_port,
    all_nodes_started_timeout,
    ray_cli_path,
    run,
    ubf_context
):

    import ray
    import json
    import socket
    from metaflow import S3, current

    # Why are deco.task_pre_step and deco.task_decorate calls in the same loop?
    # https://github.com/Netflix/metaflow/blob/76eee802cba1983dffe7e7731dd8e31e2992e59b/metaflow/task.py#L553
    # The way this runs now, these current.parallel variables are still defaults on all nodes,
    # because the AWS Batch decorator's task_pre_step has not run before the task_decorate call above.
    # num_nodes = current.parallel.num_nodes
    # node_index = current.parallel.node_index

    # AWS Batch-specific workaround.
    num_nodes = int(os.environ["AWS_BATCH_JOB_NUM_NODES"])
    node_index = os.environ["AWS_BATCH_JOB_NODE_INDEX"]
    node_key = os.path.join(RAY_NODE_STARTED_VAR, "node_%s.json" % node_index)
    current._update_env({'num_nodes': num_nodes})

    # Similar to the comment above:
    # it would be better to use current.parallel.main_ip instead of this conditional block,
    # but that seems to require a change to the main loop in metaflow.task.
    if ubf_context == UBF_CONTROL:
        local_ips = socket.gethostbyname_ex(socket.gethostname())[-1]
        main_ip = local_ips[0]
    else:
        main_ip = os.environ['AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS']

    try:
        main_port = main_port or (6379 + abs(int(current.run_id)) % 1000)
    except ValueError:
        # If int() fails, i.e. run_id is not an int, fall back to a constant port.
        # hash() cannot be used because it is not stable across processes.
        main_port = 6379

    s3 = S3(run=run)
@@ -112,6 +135,7 @@ def setup_ray_distributed(
                str(main_port),
            ]
        )
        s3.put('control', json.dumps({RAY_JOB_COMPLETE_VAR: False}))
    else:
        node_ip_address = ray._private.services.get_node_ip_address()
        runtime_start_result = subprocess.run(
@@ -124,52 +148,109 @@
"%s:%s" % (main_ip, main_port),
]
)

if runtime_start_result.returncode != 0:
raise Exception("Ray runtime failed to start on node %s" % node_index)
raise RayWorkerFailedStartException(node_index)
else:
s3.put(node_key, json.dumps({'node_started': True}))

def _num_nodes_started(path="ray_nodes"):
def _num_nodes_started(path=RAY_NODE_STARTED_VAR):
objs = s3.get_recursive([path])
num_started = 0
for obj in objs:
obj = json.loads(obj.text)
if obj['node_started']:
num_started += 1
else:
raise Exception("Node {} failed to start Ray runtime".format(node_index))
raise RayWorkerFailedStartException(node_index)
return num_started

# poll until all workers have joined the cluster
if ubf_context == UBF_CONTROL:
t0 = time.time()
while _num_nodes_started() < num_nodes:
if all_nodes_started_timeout <= time.time() - t0:
raise AllNodesStartupTimeoutException()
time.sleep(10)

s3.close()


def ensure_ray_installed():
    while True:
        try:
            import ray
            break
        except ImportError:
            print("Ray is not installed. Installing latest version of ray-air package.")
            subprocess.run([sys.executable, "-m", "pip", "install", "-U", "ray[air]"])


class NodeParticipationWatcher(object):

    def __init__(self, expected_num_nodes, polling_freq=10, t_user_code_start_buffer=30):
        self.t_user_code_start_buffer = t_user_code_start_buffer
        self.expected_num_nodes = expected_num_nodes
        self.polling_freq = polling_freq
        self._thread = Thread(target=self._enforce_participation)
        self.is_alive = True
        self._thread.start()

    def end(self):
        self.is_alive = False

    def _enforce_participation(self):
        import ray

        # Why this sleep?
        # The user code is expected to call ray.init(), in line with ergonomic Ray workflows.
        # The self._num_nodes(ray) calls in the loop below require ray.init() to have run already.
        # If we did not wait for the user code to call ray.init(), we would have to call it here first,
        # which makes the user code's ray.init() fail with an error like
        # `Maybe you called ray.init twice by accident?` and asks the user to pass
        # 'ignore_reinit_error=True' to 'ray.init()', which is annoying UX.
        # So we give the user code a head start before we begin polling.
        time.sleep(self.t_user_code_start_buffer)

        while self.is_alive:
            n = self._num_nodes(ray)
            if n < self.expected_num_nodes:
                self.is_alive = False
                self._kill_run(n)
            time.sleep(self.polling_freq)

    def _num_nodes(self, ray):
        return len(ray._private.state.state._live_node_ids())  # Should this use ray._private.state.node_ids()?

    def _kill_run(self, n):
        msg = "Only {} of the expected {} nodes are still participating. Spinning down this run.".format(n, self.expected_num_nodes)
        print(msg)
        os.kill(os.getpid(), signal.SIGINT)


class ControlTaskException(MetaflowException):
    headline = "Control task error"

    def __init__(self, e):
        msg = """
Spinning down all workers because of the following exception running the @step code on the control task:
{exception_str}
""".format(exception_str=str(e))
        super(ControlTaskException, self).__init__(msg)


class RayWorkerFailedStartException(MetaflowException):
    headline = "Worker task startup error"

    def __init__(self, node_index):
        msg = "Worker task failed to start on node {}".format(node_index)
        super(RayWorkerFailedStartException, self).__init__(msg)


class AllNodesStartupTimeoutException(MetaflowException):
    headline = "All workers did not join cluster error"

    def __init__(self):
        msg = "Exiting the job because it timed out waiting for all workers to join the cluster. " \
              "You can set the timeout with @ray_parallel(all_nodes_started_timeout=X)."
        super(AllNodesStartupTimeoutException, self).__init__(msg)
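
For orientation, here is a minimal sketch of a flow that could exercise this decorator once the commit lands. The flow name, the import of ray_parallel from metaflow, the num_parallel width, and the @batch resource numbers are illustrative assumptions rather than anything this diff defines; self.checkpoint_path and the decorator attributes (main_port, worker_polling_freq, all_nodes_started_timeout) come from the code above.

# Illustrative usage sketch -- not part of this commit.
# Assumes ray_parallel is registered and importable from metaflow, and that ray[air] is installed.
from metaflow import FlowSpec, step, batch, ray_parallel


class RayDemoFlow(FlowSpec):  # hypothetical flow name

    @step
    def start(self):
        # Fan out to two nodes; the control task runs the user code, workers just heartbeat.
        self.next(self.train, num_parallel=2)

    @batch(cpu=8, memory=32000)  # resource numbers are placeholders
    @ray_parallel(all_nodes_started_timeout=120, worker_polling_freq=10)
    @step
    def train(self):
        import ray

        ray.init()  # attaches the driver to the cluster started by setup_ray_distributed()
        # self.checkpoint_path is the attribute set via RAY_CHECKPOINT_VAR_NAME:
        # ./ray_checkpoints locally, or an S3 prefix under DATATOOLS_S3ROOT on AWS Batch.
        print("Ray resources:", ray.cluster_resources(), "checkpoints go to:", self.checkpoint_path)
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    RayDemoFlow()

Note that only the control task runs the @step body; worker tasks run _worker_heartbeat and exit once the control task writes the ray_job_completed flag to S3.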
