diff --git a/metaflow/graph.py b/metaflow/graph.py
index 5013971eb28..3d199959359 100644
--- a/metaflow/graph.py
+++ b/metaflow/graph.py
@@ -3,6 +3,7 @@
 import re
 
 from itertools import chain
+from typing import List, Optional
 
 from .util import to_pod
 
@@ -82,6 +83,16 @@ def __init__(
        self.matching_join = None
        # these attributes are populated by _postprocess
        self.is_inside_foreach = False
+        # Conditional info
+        self.is_conditional = False  # will this node always be executed, or is it in a conditional branch?
+        self.matching_conditional_join = None  # which step joins the conditional branches; filled in for split-switch only.
+        self.is_conditional_join = (
+            False  # Does this node 'join' some set of conditional branches?
+        )
+        self.conditional_parents = []
+        self.conditional_branch = (
+            []
+        )  # All the steps leading to this node that depend on a condition, starting from the split-switch

    def _expr_str(self, expr):
        return "%s.%s" % (expr.value.id, expr.attr)
@@ -296,8 +307,61 @@ def _postprocess(self):
            if [f for f in foreaches if self.nodes[f].matching_join != node.name]:
                node.is_inside_foreach = True

+            # Fill in conditional-related info.
+            if node.conditional_parents:
+                # do the required postprocessing for anything requiring node.in_funcs
+
+                # check that in previous parsing we have not closed all conditional in_funcs.
+                # If so, this step cannot be conditional either.
+                node.is_conditional = any(
+                    self[in_func].is_conditional or self[in_func].type == "split-switch"
+                    for in_func in node.in_funcs
+                )
+                # does this node close the latest conditional parent branches?
+                conditional_in_funcs = [
+                    in_func
+                    for in_func in node.in_funcs
+                    if self[in_func].conditional_branch
+                ]
+                closed_conditional_parents = []
+                for last_split_switch in node.conditional_parents[::-1]:
+                    last_conditional_split_nodes = self[last_split_switch].out_funcs
+                    # p needs to be in at least one conditional_branch for it to be closed.
+                    if all(
+                        any(
+                            p in self[in_func].conditional_branch
+                            for in_func in conditional_in_funcs
+                        )
+                        for p in last_conditional_split_nodes
+                    ):
+                        closed_conditional_parents.append(last_split_switch)
+
+                        node.is_conditional_join = True
+                        self[last_split_switch].matching_conditional_join = node.name
+
+                # Did we close all conditionals? Then this branch and all its children are not conditional anymore.
+                if not [
+                    p
+                    for p in node.conditional_parents
+                    if p not in closed_conditional_parents
+                ]:
+                    node.is_conditional = False
+                    node.conditional_parents = []
+                    for p in node.out_funcs:
+                        child = self[p]
+                        child.is_conditional = False
+                        child.conditional_parents = []
+
    def _traverse_graph(self):
-        def traverse(node, seen, split_parents, split_branches):
+        def traverse(
+            node: DAGNode,
+            seen,
+            split_parents,
+            split_branches,
+            conditional_branch: List[str],
+            conditional_parents: Optional[List[str]] = None,
+        ):
            add_split_branch = False
            try:
                self.sorted_nodes.remove(node.name)
@@ -312,6 +376,13 @@ def traverse(node, seen, split_parents, split_branches):
            elif node.type == "split-switch":
                node.split_parents = split_parents
                node.split_branches = split_branches
+
+                node.conditional_branch = conditional_branch + [node.name]
+                node.conditional_parents = (
+                    [node.name]
+                    if not conditional_parents
+                    else conditional_parents + [node.name]
+                )
            elif node.type == "join":
                # ignore joins without splits
                if split_parents:
@@ -324,6 +395,11 @@ def traverse(node, seen, split_parents, split_branches):
                node.split_parents = split_parents
                node.split_branches = split_branches

+            if conditional_parents and not node.type == "split-switch":
+                node.conditional_parents = conditional_parents
+                node.conditional_branch = conditional_branch + [node.name]
+                node.is_conditional = True
+
            for n in node.out_funcs:
                # graph may contain loops - ignore them
                if n not in seen:
@@ -336,10 +412,12 @@ def traverse(node, seen, split_parents, split_branches):
                        seen + [n],
                        split_parents,
                        split_branches + ([n] if add_split_branch else []),
+                        node.conditional_branch,
+                        node.conditional_parents,
                    )

        if "start" in self:
-            traverse(self["start"], [], [], [])
+            traverse(self["start"], [], [], [], [])

        # fix the order of in_funcs
        for node in self.nodes.values():
diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py
index 7e6ac43ce95..1fb21a08636 100644
--- a/metaflow/plugins/argo/argo_workflows.py
+++ b/metaflow/plugins/argo/argo_workflows.py
@@ -941,6 +941,7 @@ def _visit(
            dag_tasks = []
            if templates is None:
                templates = []
+
            if exit_node is not None and exit_node is node.name:
                return templates, dag_tasks
            if node.name == "start":
@@ -948,12 +949,7 @@ def _visit(
                dag_task = DAGTask(self._sanitize(node.name)).template(
                    self._sanitize(node.name)
                )
-            if node.type == "split-switch":
-                raise ArgoWorkflowsException(
-                    "Deploying flows with switch statement "
-                    "to Argo Workflows is not supported currently."
-                )
-            elif (
+            if (
                node.is_inside_foreach
                and self.graph[node.in_funcs[0]].type == "foreach"
                and not self.graph[node.in_funcs[0]].parallel_foreach
@@ -1087,15 +1083,43 @@ def _visit(
                    ]
                )
+                conditional_deps = [
+                    "%s.Succeeded" % self._sanitize(in_func)
+                    for in_func in node.in_funcs
+                    if self.graph[in_func].is_conditional
+                ]
+                required_deps = [
+                    "%s.Succeeded" % self._sanitize(in_func)
+                    for in_func in node.in_funcs
+                    if not self.graph[in_func].is_conditional
+                ]
+                both_conditions = required_deps and conditional_deps
+
+                depends_str = "{required}{_and}{conditional}".format(
+                    required=("(%s)" if both_conditions else "%s")
+                    % " && ".join(required_deps),
+                    _and=" && " if both_conditions else "",
+                    conditional=("(%s)" if both_conditions else "%s")
+                    % " || ".join(conditional_deps),
+                )
                dag_task = (
                    DAGTask(self._sanitize(node.name))
-                    .dependencies(
-                        [self._sanitize(in_func) for in_func in node.in_funcs]
-                    )
+                    .depends(depends_str)
                    .template(self._sanitize(node.name))
                    .arguments(Arguments().parameters(parameters))
                )
+            # Add conditional if this is the first step in a conditional branch
+            if (
+                node.is_conditional
+                and self.graph[node.in_funcs[0]].type == "split-switch"
+            ):
+                in_func = node.in_funcs[0]
+                dag_task.when(
+                    "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                    % (self._sanitize(in_func), node.name)
+                )
+
            dag_tasks.append(dag_task)

        # End the workflow if we have reached the end of the flow
        if node.type == "end":
@@ -1121,6 +1145,23 @@ def _visit(
                dag_tasks,
                parent_foreach,
            )
+        elif node.type == "split-switch":
+            for n in node.out_funcs:
+                _visit(
+                    self.graph[n],
+                    node.matching_conditional_join,
+                    templates,
+                    dag_tasks,
+                    parent_foreach,
+                )
+
+            return _visit(
+                self.graph[node.matching_conditional_join],
+                exit_node,
+                templates,
+                dag_tasks,
+                parent_foreach,
+            )
        # For foreach nodes generate a new sub DAGTemplate
        # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
        elif node.type == "foreach":
@@ -1148,7 +1189,7 @@ def _visit(
            #
            foreach_task = (
                DAGTask(foreach_template_name)
-                .dependencies([self._sanitize(node.name)])
+                .depends(f"{self._sanitize(node.name)}.Succeeded")
                .template(foreach_template_name)
                .arguments(
                    Arguments().parameters(
@@ -1193,6 +1234,15 @@ def _visit(
                            % self._sanitize(node.name)
                        )
                    )
+            # Add conditional if this is the first step in a conditional branch
+            if node.is_conditional and not any(
+                self.graph[in_func].is_conditional for in_func in node.in_funcs
+            ):
+                in_func = node.in_funcs[0]
+                foreach_task.when(
+                    "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                    % (self._sanitize(in_func), node.name)
+                )
            dag_tasks.append(foreach_task)
            templates, dag_tasks_1 = _visit(
                self.graph[node.out_funcs[0]],
@@ -1236,7 +1286,22 @@ def _visit(
                                self.graph[node.matching_join].in_funcs[0]
                            )
                        }
-                    )
+                        if not self.graph[
+                            node.matching_join
+                        ].is_conditional_join
+                        else
+                        # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
+                        # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
+                        {
+                            "expression": "get((%s)?.parameters, 'task-id')"
+                            % " ?? ".join(
".join( + f"tasks['{self._sanitize(func)}']?.outputs" + for func in self.graph[ + node.matching_join + ].in_funcs + ) + } + ), ] if not node.parallel_foreach else [ @@ -1269,7 +1334,7 @@ def _visit( join_foreach_task = ( DAGTask(self._sanitize(self.graph[node.matching_join].name)) .template(self._sanitize(self.graph[node.matching_join].name)) - .dependencies([foreach_template_name]) + .depends(f"{foreach_template_name}.Succeeded") .arguments( Arguments().parameters( ( @@ -1378,6 +1443,7 @@ def _container_templates(self): task_idx = "" input_paths = "" + input_paths_filename = "" root_input = None # export input_paths as it is used multiple times in the container script # and we do not want to repeat the values. @@ -1568,6 +1634,11 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) + elif node.is_conditional_join: + input_paths_filename = ( + "$(python -m metaflow.plugins.argo.create_input_paths_file '%s' '%s')" + % (",".join(node.in_funcs), f"{self.flow.name}/{run_id}") + ) elif ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" @@ -1600,7 +1671,11 @@ def _container_templates(self): "--task-id %s" % task_id, "--retry-count %s" % retry_count, "--max-user-code-retries %d" % user_code_retries, - "--input-paths %s" % input_paths, + ( + "--input-paths-filename %s" % input_paths_filename + if input_paths_filename + else "--input-paths %s" % input_paths + ), ] if node.parallel_step: step.append( @@ -1814,7 +1889,7 @@ def _container_templates(self): [Parameter("num-parallel"), Parameter("task-id-entropy")] ) else: - # append this only for joins of foreaches, not static splits + # append these only for joins of foreaches, not static splits inputs.append(Parameter("split-cardinality")) # check if the node is a @parallel node. elif node.parallel_step: @@ -1849,6 +1924,13 @@ def _container_templates(self): # are derived at runtime. 
        if not (node.name == "end" or node.parallel_step):
            outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
+
+        # If this is a split-switch step, we need to output the name of the step chosen by the switch
+        if node.type == "split-switch":
+            outputs.append(
+                Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
+            )
+
        if node.type == "foreach":
            # Emit split cardinality from foreach task
            outputs.append(
@@ -3981,6 +4063,10 @@ def dependencies(self, dependencies):
        self.payload["dependencies"] = dependencies
        return self

+    def depends(self, depends: str):
+        self.payload["depends"] = depends
+        return self
+
    def template(self, template):
        # Template reference
        self.payload["template"] = template
@@ -3992,6 +4078,10 @@ def inline(self, template):
        self.payload["inline"] = template.to_json()
        return self

+    def when(self, when: str):
+        self.payload["when"] = when
+        return self
+
    def with_param(self, with_param):
        self.payload["withParam"] = with_param
        return self
diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py
index ce92d34b5b4..67c8fd91363 100644
--- a/metaflow/plugins/argo/argo_workflows_decorator.py
+++ b/metaflow/plugins/argo/argo_workflows_decorator.py
@@ -123,6 +123,15 @@ def task_finished(
            with open("/mnt/out/split_cardinality", "w") as file:
                json.dump(flow._foreach_num_splits, file)

+        # For conditional branches we need to record the value of the switch to disk, in order to pass it as an
+        # output from the switching step to be used further down the DAG.
+        if graph[step_name].type == "split-switch":
+            # TODO: A nicer way to access the chosen step?
+            _out_funcs, _ = flow._transition
+            chosen_step = _out_funcs[0]
+            with open("/mnt/out/switch_step", "w") as file:
+                file.write(chosen_step)
+
        # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
        # to run the task. In this case, we cannot set anything in the
        # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
diff --git a/metaflow/plugins/argo/create_input_paths_file.py b/metaflow/plugins/argo/create_input_paths_file.py
new file mode 100644
index 00000000000..619b2b8e9bc
--- /dev/null
+++ b/metaflow/plugins/argo/create_input_paths_file.py
@@ -0,0 +1,34 @@
+import sys
+from metaflow import Run
+from metaflow.util import compress_list
+from tempfile import NamedTemporaryFile
+
+# This utility uses the Metadata Service to fetch the completed tasks of the provided step_names for a
+# specific run and writes them out to a file, printing the file path as a result.
+# This is required because, after introducing conditional branching to Metaflow, the steps feeding a join
+# are no longer deterministic: during graph parsing we only know the set of possible steps that lead to
+# the executing step.
+
+
+def fetch_input_paths(step_names, run_pathspec):
+    steps = step_names.split(",")
+    run = Run(run_pathspec, _namespace_check=False)
+
+    input_paths = []
+    for step in steps:
+        try:
+            # for input paths we require the pathspec without the Flow name
+            input_paths.extend(f"{run.id}/{step}/{task.id}" for task in run[step])
+        except KeyError:
+            # a step might never have executed because it is conditional.
+            pass
+
+    return input_paths
+
+
+if __name__ == "__main__":
+    input_paths = fetch_input_paths(sys.argv[1], sys.argv[2])
+    # we use the Metaflow-internal compress_list because --input-paths-filename processing relies on decompress_list.
+    compressed = compress_list(input_paths)
+
+    with NamedTemporaryFile(delete=False) as f:
+        f.write(compressed.encode("utf-8"))
+    print(f.name)
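
For illustration (not part of the patch): a minimal sketch of the `depends` string that the `_visit` changes above produce for a join with both an unconditional parent and two conditional parents. Step names are hypothetical; the formatting logic mirrors the snippet added to argo_workflows.py.

    # Hypothetical deps for a join fed by one required step and two switch branches.
    required_deps = ["prepare.Succeeded"]
    conditional_deps = ["branch-a.Succeeded", "branch-b.Succeeded"]
    both_conditions = required_deps and conditional_deps

    depends_str = "{required}{_and}{conditional}".format(
        required=("(%s)" if both_conditions else "%s") % " && ".join(required_deps),
        _and=" && " if both_conditions else "",
        conditional=("(%s)" if both_conditions else "%s") % " || ".join(conditional_deps),
    )
    print(depends_str)
    # (prepare.Succeeded) && (branch-a.Succeeded || branch-b.Succeeded)

Required parents are AND-ed while conditional parents are OR-ed, so the join can run as soon as any one switch branch succeeds alongside all unconditional parents.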
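
A rough Python analogue (with made-up task data) of the expr-lang expression emitted for conditional joins, get((tasks['a']?.outputs ?? tasks['b']?.outputs)?.parameters, 'task-id'): in expr-lang, `?.` short-circuits on nil and `??` returns the first non-nil operand, so the join resolves its input task-id from whichever branch actually ran.

    # Made-up task results: branch-a was skipped, branch-b executed.
    tasks = {
        "branch-a": {"outputs": None},
        "branch-b": {"outputs": {"parameters": {"task-id": "42"}}},
    }
    # `or` plays the role of `??`, falling through the skipped branch.
    outputs = tasks["branch-a"]["outputs"] or tasks["branch-b"]["outputs"]
    # The trailing `?.parameters` guard corresponds to the `if outputs` check.
    task_id = outputs["parameters"]["task-id"] if outputs else None
    print(task_id)  # 42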
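
Round-trip sketch for create_input_paths_file: the temp file holds a compress_list()-encoded list that the step CLI restores with decompress_list() when handling --input-paths-filename. The pathspecs below are made up.

    from metaflow.util import compress_list, decompress_list

    # Hypothetical run/step/task pathspecs (without the flow name).
    paths = ["argo-myflow-abc/branch_a/t-123", "argo-myflow-abc/prepare/t-456"]
    encoded = compress_list(paths)
    # The step CLI decodes the file contents back into the original list.
    assert decompress_list(encoded) == paths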