From 87f3bf52a0cdbf4c307c9c6c17828b245b5505b7 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 8 Aug 2025 18:53:30 +0300 Subject: [PATCH 01/15] wip --- metaflow/plugins/argo/argo_workflows.py | 54 ++++++++++++++++--- .../plugins/argo/argo_workflows_decorator.py | 7 +++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7e6ac43ce95..51f7c678adb 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -928,6 +928,7 @@ def _visit( templates=None, dag_tasks=None, parent_foreach=None, + visited_nodes=None, ): # Returns Tuple[List[Template], List[DAGTask]] """ """ # Every for-each node results in a separate subDAG and an equivalent @@ -937,10 +938,23 @@ def _visit( # of the for-each node. # Emit if we have reached the end of the sub workflow + if visited_nodes is None: + visited_nodes = set() if dag_tasks is None: + print("RESET DAG_TASKS") dag_tasks = [] if templates is None: + print("RESET TEMPLATES") templates = [] + + # Break early if we have reached a node we already visited. Happens when parsing through all conditional branches of split-switch + if node.name in visited_nodes: + print(f"BROKE EARLY on step :{node.name}") + return templates, dag_tasks + else: + print(f"added to visited :{node.name}") + visited_nodes.add(node.name) + if exit_node is not None and exit_node is node.name: return templates, dag_tasks if node.name == "start": @@ -948,12 +962,7 @@ def _visit( dag_task = DAGTask(self._sanitize(node.name)).template( self._sanitize(node.name) ) - if node.type == "split-switch": - raise ArgoWorkflowsException( - "Deploying flows with switch statement " - "to Argo Workflows is not supported currently." - ) - elif ( + if ( node.is_inside_foreach and self.graph[node.in_funcs[0]].type == "foreach" and not self.graph[node.in_funcs[0]].parallel_foreach @@ -1113,6 +1122,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) return _visit( self.graph[node.matching_join], @@ -1120,6 +1130,27 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, + ) + elif node.type == "split-switch": + # Traverse all branches of a switch split. This should work as all branches lead to 'exit_node' + for n in node.out_funcs[:-1]: + _visit( + self.graph[n], + exit_node, + templates, + dag_tasks, + parent_foreach, + visited_nodes, + ) + + return _visit( + self.graph[node.out_funcs[-1:][0]], + exit_node, + templates, + dag_tasks, + parent_foreach, + visited_nodes, ) # For foreach nodes generate a new sub DAGTemplate # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) @@ -1200,6 +1231,7 @@ def _visit( templates, [], node.name, + visited_nodes, ) # How do foreach's work on Argo: @@ -1318,6 +1350,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) # For linear nodes continue traversing to the next node if node.type in ("linear", "join", "start"): @@ -1327,6 +1360,7 @@ def _visit( templates, dag_tasks, parent_foreach, + visited_nodes, ) else: raise ArgoWorkflowsException( @@ -1849,6 +1883,14 @@ def _container_templates(self): # are derived at runtime. if not (node.name == "end" or node.parallel_step): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] + + # If this step is a split-switch one, we need to output the switch step name + # Note we can not use node.type for this, as the start step can also be a switching one + if node.type == "split-switch": + outputs.append( + Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"}) + ) + if node.type == "foreach": # Emit split cardinality from foreach task outputs.append( diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index ce92d34b5b4..1a020b96db5 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -123,6 +123,13 @@ def task_finished( with open("/mnt/out/split_cardinality", "w") as file: json.dump(flow._foreach_num_splits, file) + # For conditional branches we need to record the value of the switch to disk, in order to pass it as an + # output from the switching step to be used further down the DAG + if graph[step_name].type == "switch-split": + switch_step_name = getattr(self, graph[step_name].condition) + with open("/mnt/out/switch_step", "w") as file: + json.dump(switch_step_name, file) + # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets # to run the task. In this case, we cannot set anything in the # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions. From a00d57d7cb0f4e2d1f8f310cb9a5870a0e74c37a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 14:17:55 +0300 Subject: [PATCH 02/15] add conditional info to graph parsing --- metaflow/graph.py | 60 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 5013971eb28..fd56d8000fb 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -3,6 +3,7 @@ import re from itertools import chain +from typing import List, Optional from .util import to_pod @@ -80,6 +81,9 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None + self.is_conditional = False # will this node always be executed, or is it in a conditional branch? + self.conditional_branch = [] + self.conditional_join = None # Node where conditional branches end, and further nodes always execute. # these attributes are populated by _postprocess self.is_inside_foreach = False @@ -297,7 +301,14 @@ def _postprocess(self): node.is_inside_foreach = True def _traverse_graph(self): - def traverse(node, seen, split_parents, split_branches): + def traverse( + node, + seen, + split_parents, + split_branches, + conditional_branch: List[str], + conditional_root_nodes: Optional[List[List[str]]] = None, + ): add_split_branch = False try: self.sorted_nodes.remove(node.name) @@ -312,6 +323,14 @@ def traverse(node, seen, split_parents, split_branches): elif node.type == "split-switch": node.split_parents = split_parents node.split_branches = split_branches + + conditional_branch = conditional_branch + [node.name] + node.conditional_branch = conditional_branch + conditional_root_nodes = ( + [node.out_funcs] + if not conditional_root_nodes + else conditional_root_nodes + [node.out_funcs] + ) elif node.type == "join": # ignore joins without splits if split_parents: @@ -324,6 +343,41 @@ def traverse(node, seen, split_parents, split_branches): node.split_parents = split_parents node.split_branches = split_branches + if conditional_root_nodes and not node.type == "split-switch": + conditional_branch = conditional_branch + [node.name] + node.conditional_branch = conditional_branch + # Multiple cases for conditional branching. TODO: describe the structure + # 1. we are in only one conditional branch + # 2. we are in a nested conditional branch + + *root_nodes, last_root_nodes = conditional_root_nodes + # Check if the node is joining all of the conditional root nodes branches. + is_conditional_join = all( + any(p in last_root_nodes for p in self[in_func].conditional_branch) + for in_func in node.in_funcs + ) + + if is_conditional_join: + conditional_root_nodes = root_nodes + + # we are in a conditional branch if we have conditional root nodes left open, and + # we did not join the most recent conditional branches. + is_in_conditional_branch = ( + bool(conditional_root_nodes) and not is_conditional_join + ) + + if not is_in_conditional_branch: + conditional_branch = [] + # add the conditional join step info + for n in set( + step + for in_func in node.in_funcs + for step in self[in_func].conditional_branch + ): + self[n].conditional_join = node.name + + node.is_conditional = is_in_conditional_branch + for n in node.out_funcs: # graph may contain loops - ignore them if n not in seen: @@ -336,10 +390,12 @@ def traverse(node, seen, split_parents, split_branches): seen + [n], split_parents, split_branches + ([n] if add_split_branch else []), + conditional_branch, + conditional_root_nodes, ) if "start" in self: - traverse(self["start"], [], [], []) + traverse(self["start"], [], [], [], []) # fix the order of in_funcs for node in self.nodes.values(): From 3ee6557f5ce6f4c07a105da88d7a943bb3fcff0a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 18:05:35 +0300 Subject: [PATCH 03/15] correctly dump chosen step to disk for argo --- metaflow/plugins/argo/argo_workflows_decorator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows_decorator.py b/metaflow/plugins/argo/argo_workflows_decorator.py index 1a020b96db5..67c8fd91363 100644 --- a/metaflow/plugins/argo/argo_workflows_decorator.py +++ b/metaflow/plugins/argo/argo_workflows_decorator.py @@ -125,10 +125,12 @@ def task_finished( # For conditional branches we need to record the value of the switch to disk, in order to pass it as an # output from the switching step to be used further down the DAG - if graph[step_name].type == "switch-split": - switch_step_name = getattr(self, graph[step_name].condition) + if graph[step_name].type == "split-switch": + # TODO: A nicer way to access the chosen step? + _out_funcs, _ = flow._transition + chosen_step = _out_funcs[0] with open("/mnt/out/switch_step", "w") as file: - json.dump(switch_step_name, file) + file.write(chosen_step) # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets # to run the task. In this case, we cannot set anything in the From 0e99798a8937dc373d40a1a3a4513f6bbf9cd9b9 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 11 Aug 2025 18:53:39 +0300 Subject: [PATCH 04/15] fix conditional usage for argo DAG --- metaflow/plugins/argo/argo_workflows.py | 66 +++++++++++++++---------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 51f7c678adb..a2770e9687c 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -928,7 +928,6 @@ def _visit( templates=None, dag_tasks=None, parent_foreach=None, - visited_nodes=None, ): # Returns Tuple[List[Template], List[DAGTask]] """ """ # Every for-each node results in a separate subDAG and an equivalent @@ -938,23 +937,11 @@ def _visit( # of the for-each node. # Emit if we have reached the end of the sub workflow - if visited_nodes is None: - visited_nodes = set() if dag_tasks is None: - print("RESET DAG_TASKS") dag_tasks = [] if templates is None: - print("RESET TEMPLATES") templates = [] - # Break early if we have reached a node we already visited. Happens when parsing through all conditional branches of split-switch - if node.name in visited_nodes: - print(f"BROKE EARLY on step :{node.name}") - return templates, dag_tasks - else: - print(f"added to visited :{node.name}") - visited_nodes.add(node.name) - if exit_node is not None and exit_node is node.name: return templates, dag_tasks if node.name == "start": @@ -1096,15 +1083,42 @@ def _visit( ] ) + conditional_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if self.graph[in_func].is_conditional + ] + required_deps = [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in node.in_funcs + if not self.graph[in_func].is_conditional + ] + both_conditions = required_deps and conditional_deps + + depends_str = "{required}{_and}{conditional}".format( + required=("(%s)" if both_conditions else "%s") + % " && ".join(required_deps), + _and=" && " if both_conditions else "", + conditional=("(%s)" if both_conditions else "%s") + % " || ".join(conditional_deps), + ) dag_task = ( DAGTask(self._sanitize(node.name)) - .dependencies( - [self._sanitize(in_func) for in_func in node.in_funcs] - ) + .depends(depends_str) .template(self._sanitize(node.name)) .arguments(Arguments().parameters(parameters)) ) + # Add conditional if this is the first step in a conditional branch + if node.is_conditional and not any( + self.graph[in_func].is_conditional for in_func in node.in_funcs + ): + in_func = node.in_funcs[0] + dag_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) + dag_tasks.append(dag_task) # End the workflow if we have reached the end of the flow if node.type == "end": @@ -1133,24 +1147,21 @@ def _visit( visited_nodes, ) elif node.type == "split-switch": - # Traverse all branches of a switch split. This should work as all branches lead to 'exit_node' - for n in node.out_funcs[:-1]: + for n in node.out_funcs: _visit( self.graph[n], - exit_node, + node.conditional_join, templates, dag_tasks, parent_foreach, - visited_nodes, ) return _visit( - self.graph[node.out_funcs[-1:][0]], + self.graph[node.conditional_join], exit_node, templates, dag_tasks, parent_foreach, - visited_nodes, ) # For foreach nodes generate a new sub DAGTemplate # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`) @@ -1231,7 +1242,6 @@ def _visit( templates, [], node.name, - visited_nodes, ) # How do foreach's work on Argo: @@ -1350,7 +1360,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) # For linear nodes continue traversing to the next node if node.type in ("linear", "join", "start"): @@ -1360,7 +1369,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) else: raise ArgoWorkflowsException( @@ -4023,6 +4031,10 @@ def dependencies(self, dependencies): self.payload["dependencies"] = dependencies return self + def depends(self, depends: str): + self.payload["depends"] = depends + return self + def template(self, template): # Template reference self.payload["template"] = template @@ -4034,6 +4046,10 @@ def inline(self, template): self.payload["inline"] = template.to_json() return self + def when(self, when: str): + self.payload["when"] = when + return self + def with_param(self, with_param): self.payload["withParam"] = with_param return self From d0a03fd8f6370913a74f19f488b355b7e754fb01 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 12 Aug 2025 12:27:20 +0300 Subject: [PATCH 05/15] introduce script for parsing conditional input paths. rename and introduce more properties to graph --- metaflow/graph.py | 13 ++++++++++--- metaflow/plugins/argo/argo_workflows.py | 9 +++++++-- .../plugins/argo/conditional_input_paths.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 metaflow/plugins/argo/conditional_input_paths.py diff --git a/metaflow/graph.py b/metaflow/graph.py index fd56d8000fb..26cb4ec2470 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -81,9 +81,15 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None + # Conditional info, also populated in _traverse_graph self.is_conditional = False # will this node always be executed, or is it in a conditional branch? - self.conditional_branch = [] - self.conditional_join = None # Node where conditional branches end, and further nodes always execute. + self.is_conditional_join = ( + False # Does this node 'join' some set of conditional branches? + ) + self.conditional_branch = ( + [] + ) # All the steps leading to this node that depends on a condition, starting from the split-switch + self.conditional_end_node = None # Node where conditional branches end, and further nodes always execute. # these attributes are populated by _postprocess self.is_inside_foreach = False @@ -358,6 +364,7 @@ def traverse( ) if is_conditional_join: + node.is_conditional_join = True conditional_root_nodes = root_nodes # we are in a conditional branch if we have conditional root nodes left open, and @@ -374,7 +381,7 @@ def traverse( for in_func in node.in_funcs for step in self[in_func].conditional_branch ): - self[n].conditional_join = node.name + self[n].conditional_end_node = node.name node.is_conditional = is_in_conditional_branch diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a2770e9687c..dc80754fd9a 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1150,14 +1150,14 @@ def _visit( for n in node.out_funcs: _visit( self.graph[n], - node.conditional_join, + node.conditional_end_node, templates, dag_tasks, parent_foreach, ) return _visit( - self.graph[node.conditional_join], + self.graph[node.conditional_end_node], exit_node, templates, dag_tasks, @@ -1610,6 +1610,11 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) + elif node.is_conditional_join: + input_paths = ( + "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" + % input_paths + ) elif ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" diff --git a/metaflow/plugins/argo/conditional_input_paths.py b/metaflow/plugins/argo/conditional_input_paths.py new file mode 100644 index 00000000000..b224faf35f8 --- /dev/null +++ b/metaflow/plugins/argo/conditional_input_paths.py @@ -0,0 +1,18 @@ +from math import inf +import sys +from metaflow.util import decompress_list, compress_list + + +def generate_input_paths(input_paths): + # => run_id/step/:foo,bar + paths = decompress_list(input_paths) + + # some of the paths are going to be malformed due to never having executed per conditional. + # strip these out of the list. + + trimmed = [path for path in paths if not "{{" in path] + return compress_list(trimmed, zlibmin=inf) + + +if __name__ == "__main__": + print(generate_input_paths(sys.argv[1])) From 2d8de9212c34c49a6f6826ad1228fa8e62fa95cf Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 13 Aug 2025 12:45:57 +0300 Subject: [PATCH 06/15] cleanup --- metaflow/plugins/argo/argo_workflows.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index dc80754fd9a..b5e49d6c23f 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1136,7 +1136,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) return _visit( self.graph[node.matching_join], @@ -1144,7 +1143,6 @@ def _visit( templates, dag_tasks, parent_foreach, - visited_nodes, ) elif node.type == "split-switch": for n in node.out_funcs: From dca2f378050048b882700db56053c0d07b340b61 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 14 Aug 2025 23:31:33 +0300 Subject: [PATCH 07/15] reworking conditional graph parsing --- metaflow/graph.py | 107 +++++++++++++----------- metaflow/plugins/argo/argo_workflows.py | 4 +- 2 files changed, 60 insertions(+), 51 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 26cb4ec2470..1ef03d9ac24 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -81,17 +81,18 @@ def __init__( self.split_parents = [] self.split_branches = [] self.matching_join = None - # Conditional info, also populated in _traverse_graph + # these attributes are populated by _postprocess + self.is_inside_foreach = False + # Conditional info self.is_conditional = False # will this node always be executed, or is it in a conditional branch? + self.matching_conditional_join = None # which step joins the conditional branches. filled for split-switch only. self.is_conditional_join = ( False # Does this node 'join' some set of conditional branches? ) + self.conditional_parents = [] self.conditional_branch = ( [] ) # All the steps leading to this node that depends on a condition, starting from the split-switch - self.conditional_end_node = None # Node where conditional branches end, and further nodes always execute. - # these attributes are populated by _postprocess - self.is_inside_foreach = False def _expr_str(self, expr): return "%s.%s" % (expr.value.id, expr.attr) @@ -306,14 +307,54 @@ def _postprocess(self): if [f for f in foreaches if self.nodes[f].matching_join != node.name]: node.is_inside_foreach = True + # Fill in conditionals related info. + if node.conditional_parents: + # do the required postprocessing for anything requiring node.in_funcs + + # does this node close the latest conditional parent branches? + conditional_in_funcs = [ + in_func + for in_func in node.in_funcs + if self[in_func].conditional_branch + ] + closed_conditional_parents = [] + for last_split_switch in node.conditional_parents[::-1]: + # last_split_switch = node.conditional_parents[-1] + last_conditional_split_nodes = self[last_split_switch].out_funcs + # p needs to be in at least one conditional_branch for it to be closed. + if all( + any( + p in self[in_func].conditional_branch + for in_func in conditional_in_funcs + ) + for p in last_conditional_split_nodes + ): + closed_conditional_parents.append(last_split_switch) + + node.is_conditional_join = True + self[last_split_switch].matching_conditional_join = node.name + + # Did we close all conditionals? Then this branch and all its children are not conditional anymore. + if not [ + p + for p in node.conditional_parents + if p not in closed_conditional_parents + ]: + node.is_conditional = False + node.conditional_parents = [] + for p in node.out_funcs: + child = self[p] + child.is_conditional = False + child.conditional_parents = [] + def _traverse_graph(self): def traverse( - node, + node: DAGNode, seen, split_parents, split_branches, conditional_branch: List[str], - conditional_root_nodes: Optional[List[List[str]]] = None, + conditional_parents: Optional[List[str]] = None, ): add_split_branch = False try: @@ -330,12 +371,11 @@ def traverse( node.split_parents = split_parents node.split_branches = split_branches - conditional_branch = conditional_branch + [node.name] - node.conditional_branch = conditional_branch - conditional_root_nodes = ( - [node.out_funcs] - if not conditional_root_nodes - else conditional_root_nodes + [node.out_funcs] + node.conditional_branch = conditional_branch + [node.name] + node.conditional_parents = ( + [node.name] + if not conditional_parents + else conditional_parents + [node.name] ) elif node.type == "join": # ignore joins without splits @@ -349,41 +389,10 @@ def traverse( node.split_parents = split_parents node.split_branches = split_branches - if conditional_root_nodes and not node.type == "split-switch": - conditional_branch = conditional_branch + [node.name] - node.conditional_branch = conditional_branch - # Multiple cases for conditional branching. TODO: describe the structure - # 1. we are in only one conditional branch - # 2. we are in a nested conditional branch - - *root_nodes, last_root_nodes = conditional_root_nodes - # Check if the node is joining all of the conditional root nodes branches. - is_conditional_join = all( - any(p in last_root_nodes for p in self[in_func].conditional_branch) - for in_func in node.in_funcs - ) - - if is_conditional_join: - node.is_conditional_join = True - conditional_root_nodes = root_nodes - - # we are in a conditional branch if we have conditional root nodes left open, and - # we did not join the most recent conditional branches. - is_in_conditional_branch = ( - bool(conditional_root_nodes) and not is_conditional_join - ) - - if not is_in_conditional_branch: - conditional_branch = [] - # add the conditional join step info - for n in set( - step - for in_func in node.in_funcs - for step in self[in_func].conditional_branch - ): - self[n].conditional_end_node = node.name - - node.is_conditional = is_in_conditional_branch + if conditional_parents and not node.type == "split-switch": + node.conditional_parents = conditional_parents + node.conditional_branch = conditional_branch + [node.name] + node.is_conditional = True for n in node.out_funcs: # graph may contain loops - ignore them @@ -397,8 +406,8 @@ def traverse( seen + [n], split_parents, split_branches + ([n] if add_split_branch else []), - conditional_branch, - conditional_root_nodes, + node.conditional_branch, + node.conditional_parents, ) if "start" in self: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b5e49d6c23f..dfe8b681538 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1148,14 +1148,14 @@ def _visit( for n in node.out_funcs: _visit( self.graph[n], - node.conditional_end_node, + node.matching_conditional_join, templates, dag_tasks, parent_foreach, ) return _visit( - self.graph[node.conditional_end_node], + self.graph[node.matching_conditional_join], exit_node, templates, dag_tasks, From 511bd5101fa6a5c5855eced6e1647a328c2a93c5 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 02:29:11 +0300 Subject: [PATCH 08/15] fix foreaches --- metaflow/plugins/argo/argo_workflows.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index dfe8b681538..be1db3c32cd 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1188,7 +1188,7 @@ def _visit( # foreach_task = ( DAGTask(foreach_template_name) - .dependencies([self._sanitize(node.name)]) + .depends(f"{self._sanitize(node.name)}.Succeeded") .template(foreach_template_name) .arguments( Arguments().parameters( @@ -1233,6 +1233,15 @@ def _visit( % self._sanitize(node.name) ) ) + # Add conditional if this is the first step in a conditional branch + if node.is_conditional and not any( + self.graph[in_func].is_conditional for in_func in node.in_funcs + ): + in_func = node.in_funcs[0] + foreach_task.when( + "{{tasks.%s.outputs.parameters.switch-step}}==%s" + % (self._sanitize(in_func), node.name) + ) dag_tasks.append(foreach_task) templates, dag_tasks_1 = _visit( self.graph[node.out_funcs[0]], @@ -1309,7 +1318,7 @@ def _visit( join_foreach_task = ( DAGTask(self._sanitize(self.graph[node.matching_join].name)) .template(self._sanitize(self.graph[node.matching_join].name)) - .dependencies([foreach_template_name]) + .depends(f"{foreach_template_name}.Succeeded") .arguments( Arguments().parameters( ( From 0a4f3c4c1fcf20e927296b405a72268ce6212b50 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 03:03:42 +0300 Subject: [PATCH 09/15] cleanup --- metaflow/plugins/argo/argo_workflows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index be1db3c32cd..d4adcc51548 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1905,7 +1905,6 @@ def _container_templates(self): outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})] # If this step is a split-switch one, we need to output the switch step name - # Note we can not use node.type for this, as the start step can also be a switching one if node.type == "split-switch": outputs.append( Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"}) From 923f1d12bf3fa495b29e635c0044910da0495094 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 14:45:04 +0300 Subject: [PATCH 10/15] fix argo foreach template task-id output for conditional steps --- metaflow/plugins/argo/argo_workflows.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index d4adcc51548..7436da59574 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1285,6 +1285,21 @@ def _visit( self.graph[node.matching_join].in_funcs[0] ) } + if not self.graph[ + node.matching_join + ].is_conditional_join + else + # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed. + # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md + { + "expression": "get((%s)?.outputs?.parameters, 'task-id')" + % " ?? ".join( + f"tasks['{self._sanitize(func)}']" + for func in self.graph[ + node.matching_join + ].in_funcs + ) + } ) ] if not node.parallel_foreach From 16f492f0b34abb733e723fc2f4b1c25dcbc10841 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 18:25:33 +0300 Subject: [PATCH 11/15] WIP: fix foreach pathspecs for join step with conditional parents --- metaflow/plugins/argo/argo_workflows.py | 41 +++++++++++++++++++------ 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7436da59574..25bf81b2c81 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1292,15 +1292,32 @@ def _visit( # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed. # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md { - "expression": "get((%s)?.outputs?.parameters, 'task-id')" + "expression": "get((%s)?.parameters, 'task-id')" % " ?? ".join( - f"tasks['{self._sanitize(func)}']" + f"tasks['{self._sanitize(func)}']?.outputs" for func in self.graph[ node.matching_join ].in_funcs ) } - ) + ), + # Add the out step for all foreach templates to keep things simpler. + # This is used to be able to create the input-paths correctly for the join step + Parameter("foreach-out-step").valueFrom( + { + "expression": "filter([%s], {#[1]=='Succeeded'})[0][0]" + % ",".join( + '["%s", tasks["%s"].status]' + % ( + self._sanitize(func), + self._sanitize(func), + ) + for func in self.graph[ + node.matching_join + ].in_funcs + ) + } + ), ] if not node.parallel_foreach else [ @@ -1346,6 +1363,12 @@ def _visit( "{{tasks.%s.outputs.parameters.split-cardinality}}" % self._sanitize(node.name) ), + # Only pick the output step from the first iteration of the foreach task, as it should be identical for all. + # TODO: This still needs fixing. + Parameter("foreach-out-step").value( + "{{= toJson(tasks['%s'].outputs.parameters['foreach-out-step'])[0] }}" + % foreach_template_name + ), ] if not node.parallel_foreach else [ @@ -1632,7 +1655,7 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) - elif node.is_conditional_join: + elif node.is_conditional_join and not node.type == "join": input_paths = ( "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" % input_paths @@ -1647,11 +1670,8 @@ def _container_templates(self): ) if not self.graph[node.split_parents[-1]].parallel_foreach: input_paths = ( - "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" - % ( - foreach_step, - input_paths, - ) + "$(python -m metaflow.plugins.argo.generate_input_paths {{inputs.parameters.foreach-out-step}} {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" + % (input_paths,) ) else: # Handle @parallel where output from volume mount isn't accessible @@ -1883,8 +1903,9 @@ def _container_templates(self): [Parameter("num-parallel"), Parameter("task-id-entropy")] ) else: - # append this only for joins of foreaches, not static splits + # append these only for joins of foreaches, not static splits inputs.append(Parameter("split-cardinality")) + inputs.append(Parameter("foreach-out-step")) # check if the node is a @parallel node. elif node.parallel_step: inputs.extend( From 2c1d760f9d9e6d3aba24a2df2ea3030d683d9ff0 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 15 Aug 2025 22:24:56 +0300 Subject: [PATCH 12/15] fix static joins and conditional_join again --- metaflow/plugins/argo/argo_workflows.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 25bf81b2c81..639d1b1e01b 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1655,7 +1655,11 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) - elif node.is_conditional_join and not node.type == "join": + # Only for static joins and conditional_joins + elif node.is_conditional_join and not ( + node.type == "join" + and self.graph[node.split_parents[-1]].type == "foreach" + ): input_paths = ( "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" % input_paths From 158f725ad9717b3f3056d9da0cc780eba1b7dd61 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Sat, 16 Aug 2025 01:55:01 +0300 Subject: [PATCH 13/15] revert foreach case parsing and opt for exception for now. fix graph parsing for odd cases --- metaflow/graph.py | 6 ++++ metaflow/plugins/argo/argo_workflows.py | 42 +++++++++---------------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/metaflow/graph.py b/metaflow/graph.py index 1ef03d9ac24..3d199959359 100644 --- a/metaflow/graph.py +++ b/metaflow/graph.py @@ -311,6 +311,12 @@ def _postprocess(self): if node.conditional_parents: # do the required postprocessing for anything requiring node.in_funcs + # check that in previous parsing we have not closed all conditional in_funcs. + # If so, this step can not be conditional either + node.is_conditional = any( + self[in_func].is_conditional or self[in_func].type == "split-switch" + for in_func in node.in_funcs + ) # does this node close the latest conditional parent branches? conditional_in_funcs = [ in_func diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 639d1b1e01b..67e9d659825 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1110,8 +1110,9 @@ def _visit( ) # Add conditional if this is the first step in a conditional branch - if node.is_conditional and not any( - self.graph[in_func].is_conditional for in_func in node.in_funcs + if ( + node.is_conditional + and self.graph[node.in_funcs[0]].type == "split-switch" ): in_func = node.in_funcs[0] dag_task.when( @@ -1301,23 +1302,6 @@ def _visit( ) } ), - # Add the out step for all foreach templates to keep things simpler. - # This is used to be able to create the input-paths correctly for the join step - Parameter("foreach-out-step").valueFrom( - { - "expression": "filter([%s], {#[1]=='Succeeded'})[0][0]" - % ",".join( - '["%s", tasks["%s"].status]' - % ( - self._sanitize(func), - self._sanitize(func), - ) - for func in self.graph[ - node.matching_join - ].in_funcs - ) - } - ), ] if not node.parallel_foreach else [ @@ -1363,12 +1347,6 @@ def _visit( "{{tasks.%s.outputs.parameters.split-cardinality}}" % self._sanitize(node.name) ), - # Only pick the output step from the first iteration of the foreach task, as it should be identical for all. - # TODO: This still needs fixing. - Parameter("foreach-out-step").value( - "{{= toJson(tasks['%s'].outputs.parameters['foreach-out-step'])[0] }}" - % foreach_template_name - ), ] if not node.parallel_foreach else [ @@ -1668,14 +1646,23 @@ def _container_templates(self): node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" ): + # foreach-joins straight out of conditional branches are not yet supported + if node.is_conditional_join: + raise ArgoWorkflowsException( + "Foreach steps with a conditional step as the last one are not yet supported with Argo Workflows." + "For now, you can add a merging step after the conditional ones that will be then joined by the foreach-join" + ) # Set aggregated input-paths for a for-each join foreach_step = next( n for n in node.in_funcs if self.graph[n].is_inside_foreach ) if not self.graph[node.split_parents[-1]].parallel_foreach: input_paths = ( - "$(python -m metaflow.plugins.argo.generate_input_paths {{inputs.parameters.foreach-out-step}} {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" - % (input_paths,) + "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})" + % ( + foreach_step, + input_paths, + ) ) else: # Handle @parallel where output from volume mount isn't accessible @@ -1909,7 +1896,6 @@ def _container_templates(self): else: # append these only for joins of foreaches, not static splits inputs.append(Parameter("split-cardinality")) - inputs.append(Parameter("foreach-out-step")) # check if the node is a @parallel node. elif node.parallel_step: inputs.extend( From eb8192e81f7dd65c2c2a8c08ee148aa1f42ff9cd Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 18 Aug 2025 15:30:16 +0300 Subject: [PATCH 14/15] add a script for fetching and saving input-paths to file for argo workflows conditional joins --- metaflow/plugins/argo/argo_workflows.py | 25 ++++++-------- .../plugins/argo/create_input_paths_file.py | 34 +++++++++++++++++++ 2 files changed, 44 insertions(+), 15 deletions(-) create mode 100644 metaflow/plugins/argo/create_input_paths_file.py diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 67e9d659825..1fb21a08636 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1443,6 +1443,7 @@ def _container_templates(self): task_idx = "" input_paths = "" + input_paths_filename = "" root_input = None # export input_paths as it is used multiple times in the container script # and we do not want to repeat the values. @@ -1633,25 +1634,15 @@ def _container_templates(self): ] ) input_paths = "%s/_parameters/%s" % (run_id, task_id_params) - # Only for static joins and conditional_joins - elif node.is_conditional_join and not ( - node.type == "join" - and self.graph[node.split_parents[-1]].type == "foreach" - ): - input_paths = ( - "$(python -m metaflow.plugins.argo.conditional_input_paths %s)" - % input_paths + elif node.is_conditional_join: + input_paths_filename = ( + "$(python -m metaflow.plugins.argo.create_input_paths_file '%s' '%s')" + % (",".join(node.in_funcs), f"{self.flow.name}/{run_id}") ) elif ( node.type == "join" and self.graph[node.split_parents[-1]].type == "foreach" ): - # foreach-joins straight out of conditional branches are not yet supported - if node.is_conditional_join: - raise ArgoWorkflowsException( - "Foreach steps with a conditional step as the last one are not yet supported with Argo Workflows." - "For now, you can add a merging step after the conditional ones that will be then joined by the foreach-join" - ) # Set aggregated input-paths for a for-each join foreach_step = next( n for n in node.in_funcs if self.graph[n].is_inside_foreach @@ -1680,7 +1671,11 @@ def _container_templates(self): "--task-id %s" % task_id, "--retry-count %s" % retry_count, "--max-user-code-retries %d" % user_code_retries, - "--input-paths %s" % input_paths, + ( + "--input-paths-filename %s" % input_paths_filename + if input_paths_filename + else "--input-paths %s" % input_paths + ), ] if node.parallel_step: step.append( diff --git a/metaflow/plugins/argo/create_input_paths_file.py b/metaflow/plugins/argo/create_input_paths_file.py new file mode 100644 index 00000000000..619b2b8e9bc --- /dev/null +++ b/metaflow/plugins/argo/create_input_paths_file.py @@ -0,0 +1,34 @@ +import sys +from metaflow import Run +from metaflow.util import compress_list +from tempfile import NamedTemporaryFile + +# This utility uses the Metadata Service to fetch completed tasks for the provided step_names for a specific run +# and writes them out to a file, returning the file path as a result. +# This is required due to Foreach split output steps not being deterministic anymore after introducing conditional branching to Metaflow, as during graph parsing we now only know the set of possible steps that leads to the executing step. + + +def fetch_input_paths(step_names, run_pathspec): + steps = step_names.split(",") + run = Run(run_pathspec, _namespace_check=False) + + input_paths = [] + for step in steps: + try: + # for input paths we require the pathspec without the Flow name + input_paths.extend(f"{run.id}/{step}/{task.id}" for task in run[step]) + except KeyError: + # a step might not have ever executed due to it being conditional. + pass + + return input_paths + + +if __name__ == "__main__": + input_paths = fetch_input_paths(sys.argv[1], sys.argv[2]) + # we use the Metaflow internal compress_list due to --input-paths-filename processing relying on decompress_list. + compressed = compress_list(input_paths) + + with NamedTemporaryFile(delete=False) as f: + f.write(compressed.encode("utf-8")) + print(f.name) From 4a736e72ad0ef9b36c9a1643545afa1ea3c0b87f Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 18 Aug 2025 16:06:44 +0300 Subject: [PATCH 15/15] remove unused conditionals input_path script --- .../plugins/argo/conditional_input_paths.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 metaflow/plugins/argo/conditional_input_paths.py diff --git a/metaflow/plugins/argo/conditional_input_paths.py b/metaflow/plugins/argo/conditional_input_paths.py deleted file mode 100644 index b224faf35f8..00000000000 --- a/metaflow/plugins/argo/conditional_input_paths.py +++ /dev/null @@ -1,18 +0,0 @@ -from math import inf -import sys -from metaflow.util import decompress_list, compress_list - - -def generate_input_paths(input_paths): - # => run_id/step/:foo,bar - paths = decompress_list(input_paths) - - # some of the paths are going to be malformed due to never having executed per conditional. - # strip these out of the list. - - trimmed = [path for path in paths if not "{{" in path] - return compress_list(trimmed, zlibmin=inf) - - -if __name__ == "__main__": - print(generate_input_paths(sys.argv[1]))