
bake images in parallel to speed things up
saikonen committed Jun 25, 2024
1 parent e450039 commit 2c43644
Showing 1 changed file with 82 additions and 42 deletions.
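The diff below deduplicates environment specs by hash and bakes each unique image once, in parallel, instead of baking per step sequentially. A minimal standalone sketch of that pattern, assuming a slow network-bound bake call (the `bake` stub, registry URL, and step names are illustrative placeholders, not Metaflow or FastBakery APIs):

    from concurrent.futures import ThreadPoolExecutor

    def bake(spec_hash):
        # stand-in for the slow, network-bound bake request;
        # the real call returns an image tag for the spec
        return "registry.example.com/env:%s" % spec_hash

    # steps grouped by the hash of their environment spec, so steps
    # that share a spec share a single bake request
    steps_by_hash = {
        "a1b2": ["start", "train"],
        "c3d4": ["end"],
    }

    def _bake(item):
        spec_hash, steps = item
        # one request covers every step sharing this hash
        return spec_hash, bake(spec_hash)

    with ThreadPoolExecutor() as executor:
        results = list(executor.map(_bake, steps_by_hash.items()))

    # map each spec hash to its freshly baked image
    images = {spec_hash: image for spec_hash, image in results}

ThreadPoolExecutor fits this workload because each bake is an I/O-bound HTTP request, so threads overlap the waiting even under the GIL.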
124 changes: 82 additions & 42 deletions metaflow/plugins/docker/docker_environment.py
@@ -2,6 +2,7 @@
 import json
 import os
 
+from concurrent.futures import ThreadPoolExecutor
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import (
     _USE_BAKERY,
@@ -75,8 +76,44 @@ def init_environment(self, echo):
             self._init_conda_fallback()
         echo("Baking Docker images for environment(s) ...")
         self.bakery = FastBakery(url=FAST_BAKERY_URL, auth_type=FAST_BAKERY_AUTH)
+        # 1. Check cache for known images, set them to the step if found
+        # 2. if no known image in the cache, we set up steps for a later baking process
+        # 3. bake all missing images in parallel, gather results.
+        # 4. set baked image tags for the missing steps.
+        not_cached_steps = {}
         for step in self.flow:
-            self.bake_image_for_step(step)
+            if self._is_delegated(step.name):
+                continue
+            spec_hash = self.cache_hash_for_step(step)
+            image = get_cache_image_tag(spec_hash)
+            if image:
+                self.set_image_for_step(step, image)
+                continue
+            # we don't have a cached image for the step yet, set up for baking.
+            if spec_hash in not_cached_steps:
+                not_cached_steps[spec_hash].append(step)
+            else:
+                not_cached_steps[spec_hash] = [step]
+
+        def _bake(args):
+            spec_hash, steps = args
+            # we only need to perform the bake request for one of the steps per hash
+            step = steps[0]
+            image, request = self.bake_image_for_step(step)
+            return spec_hash, image, request
+
+        # bake missing images in parallel
+        results = []
+        with ThreadPoolExecutor() as executor:
+            results = list(executor.map(_bake, not_cached_steps.items()))
+
+        # set the newly baked images for steps and cache the images for later use
+        for res in results:
+            spec_hash, image, request = res
+            for step in not_cached_steps[spec_hash]:
+                self.set_image_for_step(step, image)
+            cache_image_tag(spec_hash, image, request)
+
         echo("Environments are ready!")
         if self.steps_to_delegate:
             # TODO: add debug echo to output steps that required a conda environment.
@@ -85,10 +122,32 @@ def init_environment(self, echo):
             self.delegate.validate_environment(echo, self.datastore_type)
             self.delegate.init_environment(echo, self.steps_to_delegate)
 
-    def bake_image_for_step(self, step):
-        if self._is_delegated(step.name):
-            # do not bake images for delegated steps
-            return
+    def set_image_for_step(self, step, image):
+        # we have an image that we need to set to a kubernetes or batch decorator.
+        for deco in step.decorators:
+            if _is_remote_deco(deco):
+                deco.attributes["image"] = image
+
+    def cache_hash_for_step(self, step):
+        (
+            base_image,
+            python_version,
+            pypi_pkg,
+            conda_pkg,
+            deco_name,
+        ) = self.details_from_step(step)
+
+        packages = {**(pypi_pkg or {}), **(conda_pkg or {})}
+        sorted_keys = sorted(packages.keys())
+        base_str = "".join(
+            [FAST_BAKERY_TYPE, python_version, base_image or "", deco_name]
+        )
+        sortspec = base_str.join("%s%s" % (k, packages[k]) for k in sorted_keys).encode(
+            "utf-8"
+        )
+        return hashlib.md5(sortspec).hexdigest()
+
+    def details_from_step(self, step):
         # map out if user is requesting a base image to build on top of
         base_image = None
         for deco in step.decorators:
Expand All @@ -103,9 +162,6 @@ def bake_image_for_step(self, step):
(deco for deco in step.decorators if isinstance(deco, PyPIStepDecorator)),
None,
)
image = None
# pypi packages need to take precedence over conda.
# a conda decorator always exists alongside pypi so this needs to be accounted for
dependency_deco = pypi_deco if pypi_deco is not None else conda_deco
if dependency_deco is not None:
python = dependency_deco.attributes["python"]
@@ -119,26 +175,25 @@ def bake_image_for_step(self, step):
                 conda_pkg = pkgs
                 pypi_pkg = None
 
-        # Try getting image tag from cache first.
-        spec_hash = generate_spec_hash(
-            base_image, python, pkgs, dependency_deco.name
+        return base_image, python, pypi_pkg, conda_pkg, dependency_deco.name
+
+    def bake_image_for_step(self, step):
+        if self._is_delegated(step.name):
+            # do not bake images for delegated steps
+            return
+
+        base_image, python, pypi_pkg, conda_pkg, _deco_type = self.details_from_step(
+            step
         )
-        image = get_cache_image_tag(spec_hash)
-        if not image:
-            try:
-                image, request = self.bakery.bake(
-                    python, pypi_pkg, conda_pkg, base_image, FAST_BAKERY_TYPE
-                )
-            except FastBakeryException as ex:
-                raise DockerEnvironmentException(str(ex))
-            # cache the baked image for later use
-            cache_image_tag(spec_hash, image, request)
-
-        if image is not None:
-            # we have an image that we need to set to a kubernetes or batch decorator.
-            for deco in step.decorators:
-                if _is_remote_deco(deco):
-                    deco.attributes["image"] = image
+
+        try:
+            image, request = self.bakery.bake(
+                python, pypi_pkg, conda_pkg, base_image, FAST_BAKERY_TYPE
+            )
+        except FastBakeryException as ex:
+            raise DockerEnvironmentException(str(ex))
+
+        return image, request
 
     def _is_delegated(self, step_name):
         return step_name in self.steps_to_delegate
@@ -205,20 +260,5 @@ def read_metafile():
     return {}
 
 
-def generate_spec_hash(
-    base_image=None, python_version=None, packages={}, resolver_type=None
-):
-    sorted_keys = sorted(packages.keys())
-    base_str = "".join(
-        [FAST_BAKERY_TYPE, python_version, base_image or "", resolver_type]
-    )
-    sortspec = base_str.join("%s%s" % (k, packages[k]) for k in sorted_keys).encode(
-        "utf-8"
-    )
-    hash = hashlib.md5(sortspec).hexdigest()
-
-    return hash
-
-
 def _is_remote_deco(deco):
     return isinstance(deco, BatchDecorator) or isinstance(deco, KubernetesDecorator)
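For reference, the cache key produced by `cache_hash_for_step` in the diff above is an md5 digest over the bakery type, Python version, base image, decorator name, and the package spec sorted by package name. A small sketch of why the sort matters, with made-up package pins (the `spec_hash` helper and `"fast-bakery"` prefix are illustrative, not the module's API):

    import hashlib

    def spec_hash(base_str, packages):
        # sorting by package name makes the digest independent of dict
        # insertion order, mirroring cache_hash_for_step above
        sortspec = base_str.join(
            "%s%s" % (k, packages[k]) for k in sorted(packages)
        ).encode("utf-8")
        return hashlib.md5(sortspec).hexdigest()

    a = spec_hash("fast-bakery3.11", {"numpy": "1.26.4", "pandas": "2.2.2"})
    b = spec_hash("fast-bakery3.11", {"pandas": "2.2.2", "numpy": "1.26.4"})
    assert a == b  # same spec, same cache key, regardless of ordering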