Skip to content

Commit

Permalink
sfn runner using Deployer
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Jun 3, 2024
1 parent 19e3ba3 commit e12fa0b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 171 deletions.
27 changes: 24 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def step_functions(obj, name=None):
help="Use AWS Step Functions Distributed Map instead of Inline Map for "
"defining foreach tasks in Amazon State Language.",
)
@click.option(
"--runner-attribute-file",
default=None,
show_default=True,
type=str,
help="Write the workflow name to the file specified. Used internally for Metaflow's Runner API.",
)
@click.pass_obj
def create(
obj,
Expand All @@ -144,9 +151,14 @@ def create(
workflow_timeout=None,
log_execution_history=False,
use_distributed_map=False,
runner_attribute_file=None,
):
validate_tags(tags)

if runner_attribute_file:
with open(runner_attribute_file, "w") as f:
json.dump({"name": obj.state_machine_name}, f)

obj.echo(
"Deploying *%s* to AWS Step Functions..." % obj.state_machine_name, bold=True
)
Expand Down Expand Up @@ -232,8 +244,10 @@ def check_metadata_service_version(obj):


def resolve_state_machine_name(obj, name):
def attach_prefix(name):
if SFN_STATE_MACHINE_PREFIX is not None:
def attach_prefix(name: str):
if SFN_STATE_MACHINE_PREFIX is not None and (
not name.startswith(SFN_STATE_MACHINE_PREFIX)
):
return SFN_STATE_MACHINE_PREFIX + "_" + name
return name

Expand Down Expand Up @@ -476,7 +490,14 @@ def _convert_value(param):

if runner_attribute_file:
with open(runner_attribute_file, "w") as f:
f.write("%s:%s" % (get_metadata(), "/".join((obj.flow.name, run_id))))
json.dump(
{
"name": obj.state_machine_name,
"metadata": get_metadata(),
"pathspec": "/".join((obj.flow.name, run_id)),
},
f,
)

obj.echo(
"Workflow *{name}* triggered on AWS Step Functions "
Expand Down
219 changes: 51 additions & 168 deletions metaflow/plugins/aws/step_functions/step_functions_runner.py
Original file line number Diff line number Diff line change
@@ -1,191 +1,74 @@
import os
import sys
import tempfile
from typing import Optional, Dict

from metaflow import Run, metadata
from metaflow.exception import MetaflowNotFound
from metaflow.runner.subprocess_manager import CommandManager, SubprocessManager
from metaflow.runner.utils import clear_and_set_os_environ, read_from_file_when_ready


def get_lower_level_sfn_group(api, top_level_kwargs: Dict, name: Optional[str]):
if name is None:
return getattr(api(**top_level_kwargs), "step-functions")()
return getattr(api(**top_level_kwargs), "step-functions")(name=name)


class StepFunctionsExecutingRun(object):
def __init__(
self,
workflows_template_obj: "StepFunctionsTemplate",
runner_attribute_file_content: str,
):
self.workflows_template_obj = workflows_template_obj
self.runner = self.workflows_template_obj.runner
self.metadata_for_flow, self.pathspec = runner_attribute_file_content.rsplit(
":", maxsplit=1
)

@property
def run(self):
clear_and_set_os_environ(self.runner.old_env)
metadata(self.metadata_for_flow)
from typing import Optional, ClassVar
from metaflow.plugins.aws.step_functions.step_functions import StepFunctions
from metaflow.plugins.deployer import (
Deployer,
DeployedFlow,
TriggeredRun,
)

try:
return Run(self.pathspec, _namespace_check=False)
except MetaflowNotFound:
raise MetaflowNotFound(
"Run object not available yet, Please try again in a bit.."
)

class StepFunctionExecutingRun(TriggeredRun):
@property
def status(self):
raise NotImplementedError

def terminate(self, **kwargs):
_, run_id = self.pathspec.split("/")
command = get_lower_level_sfn_group(
self.runner.api, self.runner.top_level_kwargs, self.runner.name
).terminate(run_id=run_id, **kwargs)

pid = self.runner.spm.run_command(
[sys.executable, *command],
env=self.runner.env_vars,
cwd=self.runner.cwd,
show_output=self.runner.show_output,
)

command_obj = self.runner.spm.get(pid)
return command_obj.process.returncode == 0

class StepFunctionsStateMachine(DeployedFlow):
@property
def production_token(self):
_, production_token = StepFunctions.get_existing_deployment(self.deployer.name)
return production_token

class StepFunctionsTemplate(object):
def __init__(
self,
runner: "StepFunctionsRunner",
):
self.runner = runner

@staticmethod
def from_deployment(name):
# TODO: get the StepFunctionsTemplate object somehow from already deployed step-function, referenced by name
raise NotImplementedError
class StepFunctionsDeployer(Deployer):
type: ClassVar[Optional[str]] = "step-functions"

@property
def production_token(self):
# TODO: how to get this?
raise NotImplementedError
def create(self, **kwargs) -> DeployedFlow:
command_obj = super().create(**kwargs)

def __get_executing_sfn(self, tfp_runner_attribute, command_obj: CommandManager):
try:
content = read_from_file_when_ready(tfp_runner_attribute.name, timeout=10)
return StepFunctionsExecutingRun(self, content)
except TimeoutError as e:
stdout_log = open(command_obj.log_files["stdout"]).read()
stderr_log = open(command_obj.log_files["stderr"]).read()
command = " ".join(command_obj.command)
error_message = "Error executing: '%s':\n" % command
if stdout_log.strip():
error_message += "\nStdout:\n%s\n" % stdout_log
if stderr_log.strip():
error_message += "\nStderr:\n%s\n" % stderr_log
raise RuntimeError(error_message) from e

def trigger(self, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_runner_attribute = tempfile.NamedTemporaryFile(
dir=temp_dir, delete=False
)
command = get_lower_level_sfn_group(
self.runner.api, self.runner.top_level_kwargs, self.runner.name
).trigger(runner_attribute_file=tfp_runner_attribute.name, **kwargs)

pid = self.runner.spm.run_command(
[sys.executable, *command],
env=self.runner.env_vars,
cwd=self.runner.cwd,
show_output=self.runner.show_output,
)

command_obj = self.runner.spm.get(pid)
return self.__get_executing_sfn(tfp_runner_attribute, command_obj)


class StepFunctionsRunner(object):
def __init__(
self,
flow_file: str,
name: Optional[str] = None,
show_output: bool = False,
profile: Optional[str] = None,
env: Optional[Dict] = None,
cwd: Optional[str] = None,
**kwargs
):
from metaflow.cli import start
from metaflow.runner.click_api import MetaflowAPI

self.flow_file = flow_file

# TODO: if we don't supply it, it should default to flow name..
# which it does internally behind the scenes in CLI, but that isn't reflected here.
# This is so that we can use StepFunctionsTemplate.from_deployment(name=StepFunctionsRunner("../try.py").name)
self.name = name
self.show_output = show_output

self.old_env = os.environ.copy()
self.env_vars = self.old_env.copy()
self.env_vars.update(env or {})
if profile:
self.env_vars["METAFLOW_PROFILE"] = profile

self.cwd = cwd
self.spm = SubprocessManager()
self.top_level_kwargs = kwargs
self.api = MetaflowAPI.from_cli(self.flow_file, start)

def __enter__(self) -> "StepFunctionsRunner":
return self

def create(self, **kwargs):
command = get_lower_level_sfn_group(
self.api, self.top_level_kwargs, self.name
).create(**kwargs)
pid = self.spm.run_command(
[sys.executable, *command],
env=self.env_vars,
cwd=self.cwd,
show_output=self.show_output,
)
command_obj = self.spm.get(pid)
if command_obj.process.returncode == 0:
return StepFunctionsTemplate(runner=self)
raise Exception("Error deploying %s to Step Functions" % self.flow_file)
return StepFunctionsStateMachine(deployer=self)

raise Exception("Error deploying %s to %s" % (self.flow_file, self.type))

def __exit__(self, exc_type, exc_value, traceback):
self.spm.cleanup()
def trigger(self, **kwargs) -> TriggeredRun:
content, command_obj = super().trigger(**kwargs)

if command_obj.process.returncode == 0:
return StepFunctionExecutingRun(self, content)

def cleanup(self):
self.spm.cleanup()
raise Exception(
"Error triggering %s on %s for %s" % (self.name, self.type, self.flow_file)
)


if __name__ == "__main__":
import time

ar = StepFunctionsRunner("../try.py")
ar = StepFunctionsDeployer("../try.py")
print(ar.name)
ar_obj = ar.deploy()
print(ar.name)
ar_obj = ar.create()
print(type(ar))
print(type(ar_obj))
print(ar_obj.production_token)
result = ar_obj.trigger(alpha=300)
# print("aaa", result.status)
while True:
try:
print(result.run)
break # Exit the loop if the run object is found
except MetaflowNotFound:
print("didn't get the run object yet...")
time.sleep(5) # Wait for 5 seconds before retrying
print(result.run.id)
# print("bbb", result.status)
run = result.run
while run is None:
print("trying again...")
run = result.run
print(result.run)
time.sleep(120)
print(result.terminate())

print("triggering from deployer..")
result = ar.trigger(alpha=600)
print(result.name)
run = result.run
while run is None:
print("trying again...")
time.sleep(5)
run = result.run
print(result.run)
time.sleep(120)
print(result.terminate())

0 comments on commit e12fa0b

Please sign in to comment.