Skip to content

Commit

Permalink
options for flow level decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Jun 6, 2024
1 parent 3fa8a67 commit 0fced00
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions metaflow/runner/click_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from metaflow.cli import start
from metaflow.includefile import FilePathClass
from metaflow.parameters import JSONTypeClass
from metaflow.plugins import FLOW_DECORATORS_DESC

FLOW_DECORATORS_DESC_DICT = {t[0]: t[1] for t in FLOW_DECORATORS_DESC}

click_to_python_types = {
StringParamType: str,
Expand Down Expand Up @@ -140,8 +143,17 @@ def get_inspect_param_obj(p: Union[click.Argument, click.Option], kind: str):
parsed_modules = {}


def get_options_from_classname(class_name):
module_name, class_name = class_name.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
kls = getattr(module, class_name)
options = getattr(kls, "options", None)
return options


class FlowSpecVisitor(ast.NodeVisitor):
def __init__(self):
self.flow_level_options = {}
self.parameters = []

def construct_parameter(self, node):
Expand Down Expand Up @@ -169,6 +181,15 @@ def visit_ClassDef(self, node):
if any(
isinstance(base, ast.Name) and base.id == "FlowSpec" for base in node.bases
):
if node.decorator_list:
for each_decorator in node.decorator_list:
decorator_name = each_decorator.func.id
if decorator_name in FLOW_DECORATORS_DESC_DICT:
class_name = FLOW_DECORATORS_DESC_DICT[decorator_name]
options = get_options_from_classname(class_name)
if options:
self.flow_level_options.update(options)

# Visit all the attributes in the class
for body_item in node.body:
# Check if the attribute is an instance of Parameter
Expand All @@ -191,21 +212,24 @@ def visit_ClassDef(self, node):
self.generic_visit(node)


def extract_flowspec_params(flow_file: str) -> List[Parameter]:
def extract_flowspec_and_top_level_params(flow_file: str):
# Check if the module has already been parsed
if flow_file in parsed_modules:
return parsed_modules[flow_file]
return (
parsed_modules[flow_file].parameters,
parsed_modules[flow_file].flow_level_options,
)

with open(flow_file, "r") as file:
tree = ast.parse(file.read(), filename=flow_file)

visitor = FlowSpecVisitor()
visitor.visit(tree)

# Cache the parsed parameters
parsed_modules[flow_file] = visitor.parameters
# Cache the parsed parameters and flow_level_options
parsed_modules[flow_file] = visitor

return visitor.parameters
return visitor.parameters, visitor.flow_level_options


class MetaflowAPI(object):
Expand All @@ -225,7 +249,12 @@ def chain(self):

@classmethod
def from_cli(cls, flow_file: str, cli_collection: Callable) -> Callable:
flow_parameters = extract_flowspec_params(flow_file)
flow_parameters, flow_options = extract_flowspec_and_top_level_params(flow_file)

if flow_options:
for name, kwargs in flow_options.items():
cli_collection.params.insert(0, click.Option(("--" + name,), **kwargs))

class_dict = {"__module__": "metaflow", "_API_NAME": flow_file}
command_groups = cli_collection.sources
for each_group in command_groups:
Expand Down Expand Up @@ -428,7 +457,7 @@ def _method(_self, **kwargs):
tags=["abc", "def"],
decospecs=["kubernetes"],
max_workers=5,
alpha=3,
myparam=3,
)
print(" ".join(command))

Expand Down

0 comments on commit 0fced00

Please sign in to comment.