diff --git a/metaflow/graph.py b/metaflow/graph.py
index 227232791c..6002f5405c 100644
--- a/metaflow/graph.py
+++ b/metaflow/graph.py
@@ -2,12 +2,54 @@
 import ast
 import re
 
+def deindent_docstring(doc):
+    """Return `doc` with its common leading indentation removed.
+
+    The result is stripped of surrounding whitespace; a missing or empty
+    docstring yields ''.
+    """
+    if doc:
+        # Find the indent to remove from the docstring. We consider the following possibilities:
+        # Option 1:
+        #  """This is the first line
+        #     This is the second line
+        #  """
+        # Option 2:
+        #  """
+        # This is the first line
+        # This is the second line
+        # """
+        # Option 3:
+        #  """
+        #     This is the first line
+        #     This is the second line
+        #  """
+        #
+        # In all cases, we can find the indent to remove by doing the following:
+        #  - Check the first non-empty line, if it has an indent, use that as the base indent
+        #  - If it does not have an indent and there is a second line, check the indent of the
+        #    second line and use that
+        saw_first_line = False
+        matched_indent = None
+        for line in doc.splitlines():
+            if line:
+                matched_indent = re.match('[\t ]+', line)
+                if matched_indent is not None or saw_first_line:
+                    break
+                saw_first_line = True
+        if matched_indent:
+            return re.sub(r'\n' + matched_indent.group(), '\n', doc).strip()
+        else:
+            return doc.strip()
+    else:
+        return ''
+
 class DAGNode(object):
     def __init__(self, func_ast, decos, doc):
         self.name = func_ast.name
         self.func_lineno = func_ast.lineno
         self.decorators = decos
-        self.doc = doc.rstrip()
+        self.doc = deindent_docstring(doc)
 
         # these attributes are populated by _parse
         self.tail_next_lineno = 0
@@ -111,46 +153,34 @@ def __str__(self):
 
 class StepVisitor(ast.NodeVisitor):
 
-    def __init__(self, nodes):
+    def __init__(self, nodes, flow):
         self.nodes = nodes
+        self.flow = flow
         super(StepVisitor, self).__init__()
 
     def visit_FunctionDef(self, node):
-        decos = [d.func.id if isinstance(d, ast.Call) else d.id
-                 for d in node.decorator_list]
-        if 'step' in decos:
-            doc = ast.get_docstring(node)
-            self.nodes[node.name] = DAGNode(node, decos, doc if doc else '')
+        # Functions found in the AST are not guaranteed to be flow attributes
+        func = getattr(self.flow, node.name, None)
+        if hasattr(func, 'is_step'):
+            self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__)
 
 class FlowGraph(object):
 
-    def __init__(self, flow=None, source=None, name=None):
-        if flow:
-            module = __import__(flow.__module__)
-            source = inspect.getsource(module)
-            self.name = flow.__name__
-        else:
-            self.name = name
-
-        self.nodes = self._create_nodes(source)
+    def __init__(self, flow):
+        self.name = flow.__name__
+        self.nodes = self._create_nodes(flow)
+        self.doc = deindent_docstring(flow.__doc__)
         self._traverse_graph()
         self._postprocess()
 
-    def _create_nodes(self, source):
-        def _flow(n):
-            if isinstance(n, ast.ClassDef):
-                bases = [b.id for b in n.bases]
-                if 'FlowSpec' in bases:
-                    return self.name is None or n.name == self.name
-
-        # NOTE: this will fail if a file has multiple FlowSpec classes
-        # and no name is specified
-        [root] = list(filter(_flow, ast.parse(source).body))
-        self.name = root.name
-        doc = ast.get_docstring(root)
-        self.doc = doc if doc else ''
+    def _create_nodes(self, flow):
+        # __import__('a.b') returns package 'a'; getmodule returns the module
+        # that actually defines the flow class
+        module = inspect.getmodule(flow)
+        tree = ast.parse(inspect.getsource(module)).body
+        root = [n for n in tree\
+                if isinstance(n, ast.ClassDef) and n.name == self.name][0]
         nodes = {}
-        StepVisitor(nodes).visit(root)
+        StepVisitor(nodes, flow).visit(root)
         return nodes
 
     def _postprocess(self):