Skip to content

Commit

Permalink
EFA with pre-commit formatting (#1584)
Browse files Browse the repository at this point in the history
* add elastic fabric adapter field in batch decorator

* remove RayDecorator as out of scope for the PR

* add efa support to step-functions

* remove ray_parallel

* fix for multiple EFA devices

* handle single node use case

* add comments explaining difference in EFA job definition

* do linting

* pass pre-commit checks

---------

Co-authored-by: Riley Hun <riley.hun@autodesk.com>
Co-authored-by: Sakari Ikonen <sakari.a.ikonen@gmail.com>
Co-authored-by: Riley Hun <rileyhun@hotmail.com>
  • Loading branch information
4 people authored Oct 11, 2023
1 parent f8c9b72 commit b76e512
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 0 deletions.
5 changes: 5 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def create_job(
max_swap=None,
swappiness=None,
inferentia=None,
efa=None,
env={},
attrs={},
host_volumes=None,
Expand Down Expand Up @@ -217,6 +218,7 @@ def create_job(
.max_swap(max_swap)
.swappiness(swappiness)
.inferentia(inferentia)
.efa(efa)
.timeout_in_secs(run_time_limit)
.job_def(
image,
Expand All @@ -227,6 +229,7 @@ def create_job(
max_swap,
swappiness,
inferentia,
efa,
memory=memory,
host_volumes=host_volumes,
use_tmpfs=use_tmpfs,
Expand Down Expand Up @@ -336,6 +339,7 @@ def launch_job(
max_swap=None,
swappiness=None,
inferentia=None,
efa=None,
host_volumes=None,
use_tmpfs=None,
tmpfs_tempdir=None,
Expand Down Expand Up @@ -371,6 +375,7 @@ def launch_job(
max_swap,
swappiness,
inferentia,
efa,
env=env,
attrs=attrs,
host_volumes=host_volumes,
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ def kill(ctx, run_id, user, my_runs):
@click.option("--max-swap", help="Max Swap requirement for AWS Batch.")
@click.option("--swappiness", help="Swappiness requirement for AWS Batch.")
@click.option("--inferentia", help="Inferentia requirement for AWS Batch.")
@click.option(
"--efa",
default=0,
type=int,
help="Activate designated number of elastic fabric adapter devices. "
"EFA driver must be installed and instance type compatible with EFA",
)
@click.option("--use-tmpfs", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-tempdir", is_flag=True, help="tmpfs requirement for AWS Batch.")
@click.option("--tmpfs-size", help="tmpfs requirement for AWS Batch.")
Expand Down Expand Up @@ -173,6 +180,7 @@ def step(
max_swap=None,
swappiness=None,
inferentia=None,
efa=None,
use_tmpfs=None,
tmpfs_tempdir=None,
tmpfs_size=None,
Expand Down Expand Up @@ -300,6 +308,7 @@ def _sync_metadata():
max_swap=max_swap,
swappiness=swappiness,
inferentia=inferentia,
efa=efa,
env=env,
attrs=attrs,
host_volumes=host_volumes,
Expand Down
33 changes: 33 additions & 0 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _register_job_definition(
max_swap,
swappiness,
inferentia,
efa,
memory,
host_volumes,
use_tmpfs,
Expand Down Expand Up @@ -310,6 +311,31 @@ def _register_job_definition(
}
]

if efa:
if not (isinstance(efa, (int, unicode, basestring))):
raise BatchJobException(
"Invalid efa value: ({}) (should be 0 or greater)".format(efa)
)
else:
job_definition["containerProperties"]["linuxParameters"]["devices"] = []
if (num_parallel or 0) > 1:
# Multi-node parallel jobs require the container path and permissions explicitly specified in Job definition
for i in range(int(efa)):
job_definition["containerProperties"]["linuxParameters"][
"devices"
].append(
{
"hostPath": "/dev/infiniband/uverbs{}".format(i),
"containerPath": "/dev/infiniband/uverbs{}".format(i),
"permissions": ["READ", "WRITE", "MKNOD"],
}
)
else:
# Single-node container jobs only require host path in job definition
job_definition["containerProperties"]["linuxParameters"][
"devices"
].append({"hostPath": "/dev/infiniband/uverbs0"})

self.num_parallel = num_parallel or 0
if self.num_parallel >= 1:
job_definition["type"] = "multinode"
Expand All @@ -332,6 +358,7 @@ def _register_job_definition(
"container": job_definition["containerProperties"],
}
)

del job_definition["containerProperties"] # not used for multi-node

# check if job definition already exists
Expand Down Expand Up @@ -371,6 +398,7 @@ def job_def(
max_swap,
swappiness,
inferentia,
efa,
memory,
host_volumes,
use_tmpfs,
Expand All @@ -388,6 +416,7 @@ def job_def(
max_swap,
swappiness,
inferentia,
efa,
memory,
host_volumes,
use_tmpfs,
Expand Down Expand Up @@ -438,6 +467,10 @@ def inferentia(self, inferentia):
self._inferentia = inferentia
return self

def efa(self, efa):
    """Record the requested EFA device count on this job builder.

    The value is stored as-is on ``self._efa``; no validation happens here.

    Parameters
    ----------
    efa : int or None
        Number of elastic fabric adapter devices requested for the job.

    Returns
    -------
    self
        The same builder instance, so that calls can be chained.
    """
    self._efa = efa
    return self

def command(self, command):
if "command" not in self.payload["containerOverrides"]:
self.payload["containerOverrides"]["command"] = []
Expand Down
3 changes: 3 additions & 0 deletions metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class BatchDecorator(StepDecorator):
Path to tmpfs mount for this step. Defaults to /metaflow_temp.
inferentia : int, default: 0
Number of Inferentia chips required for this step.
efa : int, default: 0
Number of elastic fabric adapter network devices to attach to the container.
"""

name = "batch"
Expand All @@ -98,6 +100,7 @@ class BatchDecorator(StepDecorator):
"max_swap": None,
"swappiness": None,
"inferentia": None,
"efa": None,
"host_volumes": None,
"use_tmpfs": False,
"tmpfs_tempdir": True,
Expand Down
2 changes: 2 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _compile(self):
"Deploying flows with @trigger or @trigger_on_finish decorator(s) "
"to AWS Step Functions is not supported currently."
)

# Visit every node of the flow and recursively build the state machine.
def _visit(node, workflow, exit_node=None):
if node.parallel_foreach:
Expand Down Expand Up @@ -705,6 +706,7 @@ def _batch(self, node):
shared_memory=resources["shared_memory"],
max_swap=resources["max_swap"],
swappiness=resources["swappiness"],
efa=resources["efa"],
use_tmpfs=resources["use_tmpfs"],
tmpfs_tempdir=resources["tmpfs_tempdir"],
tmpfs_size=resources["tmpfs_size"],
Expand Down

0 comments on commit b76e512

Please sign in to comment.