Skip to content

Commit

Permalink
Update stubs
Browse files — browse the repository at this point in the history
  • Loading branch information
romain-intel committed Oct 22, 2024
1 parent 373ab7f commit 873a87d
Show file tree
Hide file tree
Showing 19 changed files with 239 additions and 97 deletions.
4 changes: 1 addition & 3 deletions metaflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ class and related decorators.
# Flow spec
from .flowspec import FlowSpec

from .parameters import Parameter, JSONTypeClass

JSONType = JSONTypeClass()
from .parameters import Parameter, JSONTypeClass, JSONType

# data layer
# For historical reasons, we make metaflow.plugins.datatools accessible as
Expand Down
180 changes: 139 additions & 41 deletions metaflow/cmd/develop/stub_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def __init__(self, output_dir: str, include_generated_for: bool = True):
self._mf_version = get_version()

# Contains the names of the methods that are injected in Deployer
self._deployer_injected_methods = {} # type: Dict[str, str]
self._deployer_injected_methods = (
{}
) # type: Dict[str, Dict[str, Union[Tuple[str, str], str]]]
# Contains information to add to the Current object (injected by decorators)
self._addl_current = (
dict()
Expand Down Expand Up @@ -349,7 +351,9 @@ def _reset(self):
self._current_parent_module = None # type: Optional[ModuleType]

def _get_module_name_alias(self, module_name):
if not module_name.startswith(self._root_module):
if any(
module_name.startswith(x) for x in self._safe_modules
) and not module_name.startswith(self._root_module):
return self._root_module + ".".join(
["mf_extensions", *module_name.split(".")[1:]]
)
Expand Down Expand Up @@ -463,9 +467,13 @@ def _get_module(self, alias, name):
"Adding reference %s and adding module %s as %s"
% (objname, parent_module.__name__, parent_alias)
)

obj_import_name = getattr(obj, "__name__", objname)
if obj_import_name == "<lambda>":
# We have one case of this
obj_import_name = objname
self._current_references.append(
"from %s import %s as %s" % (relative_import, objname, objname)
"from %s import %s as %s"
% (relative_import, obj_import_name, objname)
)
self._pending_modules.append((parent_alias, parent_module.__name__))
else:
Expand Down Expand Up @@ -539,11 +547,12 @@ def _add_to_typing_check(name, is_module=False):
return "None"
return element.__name__

module_name = self._get_module_name_alias(module.__name__)
if force_import:
_add_to_import(module.__name__.split(".")[0])
_add_to_typing_check(module.__name__, is_module=True)
if module.__name__ != self._current_module_name:
return "{0}.{1}".format(module.__name__, element.__name__)
_add_to_import(module_name.split(".")[0])
_add_to_typing_check(module_name, is_module=True)
if module_name != self._current_module_name:
return "{0}.{1}".format(module_name, element.__name__)
else:
return element.__name__
elif isinstance(element, type(Ellipsis)):
Expand Down Expand Up @@ -577,7 +586,7 @@ def _add_to_typing_check(name, is_module=False):
else:
return "%s[%s]" % (element.__origin__, ", ".join(args_str))
elif isinstance(element, ForwardRef):
f_arg = element.__forward_arg__
f_arg = self._get_module_name_alias(element.__forward_arg__)
# if f_arg in ("Run", "Task"): # HACK -- forward references in current.py
# _add_to_import("metaflow")
# f_arg = "metaflow.%s" % f_arg
Expand All @@ -590,13 +599,17 @@ def _add_to_typing_check(name, is_module=False):
return "typing.NamedTuple"
return str(element)
else:
print(
"WARNING: Does not handle element %s of type %s"
% (element, type(element)),
flush=True,
)
_add_to_import("typing")
return "typing.Any"
if hasattr(element, "__module__"):
elem_module = self._get_module_name_alias(element.__module__)
if elem_module == "builtins":
return getattr(element, "__name__", str(element))
_add_to_typing_check(elem_module, is_module=True)
return "{0}.{1}".format(
elem_module, getattr(element, "__name__", element)
)
else:
# A constant
return str(element)

def _exploit_annotation(self, annotation: Any, starting: str = ": ") -> str:
annotation_string = ""
Expand All @@ -610,13 +623,11 @@ def _generate_class_stub(self, name: str, clazz: type) -> str:
debug.stubgen_exec("Generating class stub for %s" % name)
skip_init = issubclass(clazz, (TriggeredRun, DeployedFlow))
if issubclass(clazz, DeployerImpl):
if clazz.TYPE is None:
# Base DeployerImpl class -- we skip
return ""
clazz_type = clazz.TYPE.replace("-", "_")
self._deployer_injected_methods[clazz_type] = (
self._current_module_name + "." + name
)
if clazz.TYPE is not None:
clazz_type = clazz.TYPE.replace("-", "_")
self._deployer_injected_methods.setdefault(clazz_type, {})[
"deployer"
] = (self._current_module_name + "." + name)

buff = StringIO()
# Class prototype
Expand Down Expand Up @@ -693,22 +704,26 @@ def _generate_class_stub(self, name: str, clazz: type) -> str:
)

parameters, _ = parse_params_from_doc(docs["param_doc"])
return_type = self._deployer_injected_methods[element.__name__]
return_type = self._deployer_injected_methods[element.__name__][
"deployer"
]

buff.write(
self._generate_function_stub(
key,
element,
sign=inspect.Signature(
parameters=[
inspect.Parameter(
"self",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
+ parameters,
return_annotation=return_type,
),
sign=[
inspect.Signature(
parameters=[
inspect.Parameter(
"self",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
+ parameters,
return_annotation=return_type,
)
],
indentation=TAB,
deco=func_deco,
)
Expand Down Expand Up @@ -742,7 +757,8 @@ def _create_multi_type(*l):
return typing.Union[l]

all_types = [
v for v in self._deployer_injected_methods.values()
v["from_deployment"][0]
for v in self._deployer_injected_methods.values()
]

if len(all_types) > 1:
Expand All @@ -754,10 +770,18 @@ def _create_multi_type(*l):
self._generate_function_stub(
key,
element,
sign=inspect.Signature(
parameters=parameters,
return_annotation=return_type,
),
sign=[
inspect.Signature(
parameters=[
inspect.Parameter(
"cls",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
+ parameters,
return_annotation=return_type,
)
],
indentation=TAB,
doc=docs["func_doc"]
+ "\n\nParameters\n----------\n"
Expand All @@ -767,7 +791,79 @@ def _create_multi_type(*l):
deco=func_deco,
)
)
elif (
clazz == DeployedFlow
and element.__name__.startswith("from_")
and element.__name__[5:] in self._deployer_injected_methods
):
# Get the doc from the from_deployment method stored in
# self._deployer_injected_methods
func_doc = inspect.cleandoc(
self._deployer_injected_methods[element.__name__[5:]][
"from_deployment"
][1]
)
docs = split_docs(
func_doc,
[
("func_doc", StartEnd(0, 0)),
(
"param_doc",
param_section_header.search(func_doc)
or StartEnd(len(func_doc), len(func_doc)),
),
(
"return_doc",
return_section_header.search(func_doc)
or StartEnd(len(func_doc), len(func_doc)),
),
],
)

parameters, _ = parse_params_from_doc(docs["param_doc"])
return_type = self._deployer_injected_methods[
element.__name__[5:]
]["from_deployment"][0]

buff.write(
self._generate_function_stub(
key,
element,
sign=[
inspect.Signature(
parameters=[
inspect.Parameter(
"cls",
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
+ parameters,
return_annotation=return_type,
)
],
indentation=TAB,
doc=docs["func_doc"]
+ "\n\nParameters\n----------\n"
+ docs["param_doc"]
+ "\n\nReturns\n-------\n"
+ docs["return_doc"],
deco=func_deco,
)
)
else:
if (
issubclass(clazz, DeployedFlow)
and clazz.TYPE is not None
and key == "from_deployment"
):
clazz_type = clazz.TYPE.replace("-", "_")
# Record docstring for this function
self._deployer_injected_methods.setdefault(clazz_type, {})[
"from_deployment"
] = (
self._current_module_name + "." + name,
element.__doc__,
)
buff.write(
self._generate_function_stub(
key,
Expand Down Expand Up @@ -1086,7 +1182,8 @@ def exploit_default(default_value: Any) -> Optional[str]:
)
elif isinstance(default_value, str):
return "'" + default_value + "'"
return getattr(default_value, "__name__", str(default_value))
else:
return self._get_element_name_with_module(default_value)

elif str(default_value).startswith("<"):
if default_value.__module__ == "builtins":
Expand Down Expand Up @@ -1131,7 +1228,8 @@ def exploit_default(default_value: Any) -> Optional[str]:
if count > 0:
buff.write("\n")

if do_overload and count < len(sign) - 1:
if do_overload: # According to mypy, we should have this on all
# so not excluding last one;; and count < len(sign) - 1:
buff.write(indentation + "@typing.overload\n")
if deco:
buff.write(indentation + deco + "\n")
Expand Down
1 change: 1 addition & 0 deletions metaflow/extension_support/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def resolve_plugins(category):
"environment": lambda x: x.TYPE,
"metadata_provider": lambda x: x.TYPE,
"datastore": lambda x: x.TYPE,
"dataclient": lambda x: x.TYPE,
"secrets_provider": lambda x: x.TYPE,
"gcp_client_provider": lambda x: x.name,
"deployer_impl_provider": lambda x: x.TYPE,
Expand Down
13 changes: 3 additions & 10 deletions metaflow/includefile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import namedtuple
import gzip

import importlib
import io
import json
import os
Expand All @@ -17,6 +18,8 @@
Parameter,
ParameterContext,
)

from .plugins import DATACLIENTS
from .util import get_username

import functools
Expand Down Expand Up @@ -47,16 +50,6 @@


# From here on out, this is the IncludeFile implementation.
from metaflow.plugins.datatools import Local, S3
from metaflow.plugins.azure.includefile_support import Azure
from metaflow.plugins.gcp.includefile_support import GS

DATACLIENTS = {
"local": Local,
"s3": S3,
"azure": Azure,
"gs": GS,
}


class IncludedFile(object):
Expand Down
3 changes: 3 additions & 0 deletions metaflow/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,6 @@ def wrapper(cmd):
return cmd

return wrapper


JSONType = JSONTypeClass()
9 changes: 9 additions & 0 deletions metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@
("gs", ".datastores.gs_storage.GSStorage"),
]

# Dataclients are used for IncludeFile
DATACLIENTS_DESC = [
("local", ".datatools.Local"),
("s3", ".datatools.S3"),
("azure", ".azure.includefile_support.Azure"),
("gs", ".gcp.includefile_support.GS"),
]

# Add non monitoring/logging sidecars here
SIDECARS_DESC = [
(
Expand Down Expand Up @@ -161,6 +169,7 @@ def get_plugin_cli():
ENVIRONMENTS = resolve_plugins("environment")
METADATA_PROVIDERS = resolve_plugins("metadata_provider")
DATASTORES = resolve_plugins("datastore")
DATACLIENTS = resolve_plugins("dataclient")
SIDECARS = resolve_plugins("sidecar")
LOGGING_SIDECARS = resolve_plugins("logging_sidecar")
MONITOR_SIDECARS = resolve_plugins("monitor_sidecar")
Expand Down
12 changes: 9 additions & 3 deletions metaflow/plugins/argo/argo_workflows_deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from metaflow.runner.deployer_impl import DeployerImpl

if TYPE_CHECKING:
from .argo_workflows_deployer_objects import ArgoWorkflowsDeployedFlow
import metaflow.plugins.argo.argo_workflows_deployer_objects


class ArgoWorkflowsDeployer(DeployerImpl):
Expand Down Expand Up @@ -38,12 +38,18 @@ def deployer_kwargs(self) -> Dict[str, Any]:
return self._deployer_kwargs

@staticmethod
def deployed_flow_type() -> Type["ArgoWorkflowsDeployedFlow"]:
def deployed_flow_type() -> (
Type[
"metaflow.plugins.argo.argo_workflows_deployer_objects.ArgoWorkflowsDeployedFlow"
]
):
from .argo_workflows_deployer_objects import ArgoWorkflowsDeployedFlow

return ArgoWorkflowsDeployedFlow

def create(self, **kwargs) -> "ArgoWorkflowsDeployedFlow":
def create(
self, **kwargs
) -> "metaflow.plugins.argo.argo_workflows_deployer_objects.ArgoWorkflowsDeployedFlow":
"""
Create a new ArgoWorkflow deployment.
Expand Down
Loading

0 comments on commit 873a87d

Please sign in to comment.