Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions metaflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re

from itertools import chain
from typing import List, Optional


from .util import to_pod
Expand Down Expand Up @@ -82,6 +83,16 @@ def __init__(
self.matching_join = None
# these attributes are populated by _postprocess
self.is_inside_foreach = False
# Conditional info
self.is_conditional = False # will this node always be executed, or is it in a conditional branch?
self.matching_conditional_join = None # which step joins the conditional branches. filled for split-switch only.
self.is_conditional_join = (
False # Does this node 'join' some set of conditional branches?
)
self.conditional_parents = []
self.conditional_branch = (
[]
) # All the steps leading to this node that depends on a condition, starting from the split-switch

# Format an attribute-access expression as "<name>.<attr>": reads the
# identifier of `expr.value` and the attribute name `expr.attr`.
# NOTE(review): presumably `expr` is an ast.Attribute whose value is an
# ast.Name (e.g. `self.foo`) -- confirm against callers; a value node without
# an `id` attribute would raise AttributeError here.
def _expr_str(self, expr):
return "%s.%s" % (expr.value.id, expr.attr)
Expand Down Expand Up @@ -296,8 +307,61 @@ def _postprocess(self):
if [f for f in foreaches if self.nodes[f].matching_join != node.name]:
node.is_inside_foreach = True

# Fill in conditionals related info.
if node.conditional_parents:
# do the required postprocessing for anything requiring node.in_funcs

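# A node stays conditional only while at least one of its in_funcs is still
# conditional (or is itself a split-switch). If earlier processing already
# closed every conditional in_func, this step cannot be conditional either.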
node.is_conditional = any(
self[in_func].is_conditional or self[in_func].type == "split-switch"
for in_func in node.in_funcs
)
# does this node close the latest conditional parent branches?
conditional_in_funcs = [
in_func
for in_func in node.in_funcs
if self[in_func].conditional_branch
]
closed_conditional_parents = []
for last_split_switch in node.conditional_parents[::-1]:
# last_split_switch = node.conditional_parents[-1]
last_conditional_split_nodes = self[last_split_switch].out_funcs
# p needs to be in at least one conditional_branch for it to be closed.
if all(
any(
p in self[in_func].conditional_branch
for in_func in conditional_in_funcs
)
for p in last_conditional_split_nodes
):
closed_conditional_parents.append(last_split_switch)

node.is_conditional_join = True
self[last_split_switch].matching_conditional_join = node.name

# Did we close all conditionals? Then this branch and all its children are not conditional anymore.
if not [
p
for p in node.conditional_parents
if p not in closed_conditional_parents
]:
node.is_conditional = False
node.conditional_parents = []
for p in node.out_funcs:
child = self[p]
child.is_conditional = False
child.conditional_parents = []

def _traverse_graph(self):
def traverse(node, seen, split_parents, split_branches):
def traverse(
node: DAGNode,
seen,
split_parents,
split_branches,
conditional_branch: List[str],
conditional_parents: Optional[List[str]] = None,
):
add_split_branch = False
try:
self.sorted_nodes.remove(node.name)
Expand All @@ -312,6 +376,13 @@ def traverse(node, seen, split_parents, split_branches):
elif node.type == "split-switch":
node.split_parents = split_parents
node.split_branches = split_branches

node.conditional_branch = conditional_branch + [node.name]
node.conditional_parents = (
[node.name]
if not conditional_parents
else conditional_parents + [node.name]
)
elif node.type == "join":
# ignore joins without splits
if split_parents:
Expand All @@ -324,6 +395,11 @@ def traverse(node, seen, split_parents, split_branches):
node.split_parents = split_parents
node.split_branches = split_branches

if conditional_parents and not node.type == "split-switch":
node.conditional_parents = conditional_parents
node.conditional_branch = conditional_branch + [node.name]
node.is_conditional = True

for n in node.out_funcs:
# graph may contain loops - ignore them
if n not in seen:
Expand All @@ -336,10 +412,12 @@ def traverse(node, seen, split_parents, split_branches):
seen + [n],
split_parents,
split_branches + ([n] if add_split_branch else []),
node.conditional_branch,
node.conditional_parents,
)

if "start" in self:
traverse(self["start"], [], [], [])
traverse(self["start"], [], [], [], [])

# fix the order of in_funcs
for node in self.nodes.values():
Expand Down
118 changes: 104 additions & 14 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,19 +941,15 @@ def _visit(
dag_tasks = []
if templates is None:
templates = []

if exit_node is not None and exit_node is node.name:
return templates, dag_tasks
if node.name == "start":
# Start node has no dependencies.
dag_task = DAGTask(self._sanitize(node.name)).template(
self._sanitize(node.name)
)
if node.type == "split-switch":
raise ArgoWorkflowsException(
"Deploying flows with switch statement "
"to Argo Workflows is not supported currently."
)
elif (
if (
node.is_inside_foreach
and self.graph[node.in_funcs[0]].type == "foreach"
and not self.graph[node.in_funcs[0]].parallel_foreach
Expand Down Expand Up @@ -1087,15 +1083,43 @@ def _visit(
]
)

conditional_deps = [
"%s.Succeeded" % self._sanitize(in_func)
for in_func in node.in_funcs
if self.graph[in_func].is_conditional
]
required_deps = [
"%s.Succeeded" % self._sanitize(in_func)
for in_func in node.in_funcs
if not self.graph[in_func].is_conditional
]
both_conditions = required_deps and conditional_deps

depends_str = "{required}{_and}{conditional}".format(
required=("(%s)" if both_conditions else "%s")
% " && ".join(required_deps),
_and=" && " if both_conditions else "",
conditional=("(%s)" if both_conditions else "%s")
% " || ".join(conditional_deps),
)
dag_task = (
DAGTask(self._sanitize(node.name))
.dependencies(
[self._sanitize(in_func) for in_func in node.in_funcs]
)
.depends(depends_str)
.template(self._sanitize(node.name))
.arguments(Arguments().parameters(parameters))
)

# Add conditional if this is the first step in a conditional branch
if (
node.is_conditional
and self.graph[node.in_funcs[0]].type == "split-switch"
):
in_func = node.in_funcs[0]
dag_task.when(
"{{tasks.%s.outputs.parameters.switch-step}}==%s"
% (self._sanitize(in_func), node.name)
)

dag_tasks.append(dag_task)
# End the workflow if we have reached the end of the flow
if node.type == "end":
Expand All @@ -1121,6 +1145,23 @@ def _visit(
dag_tasks,
parent_foreach,
)
elif node.type == "split-switch":
for n in node.out_funcs:
_visit(
self.graph[n],
node.matching_conditional_join,
templates,
dag_tasks,
parent_foreach,
)

return _visit(
self.graph[node.matching_conditional_join],
exit_node,
templates,
dag_tasks,
parent_foreach,
)
# For foreach nodes generate a new sub DAGTemplate
# We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
elif node.type == "foreach":
Expand Down Expand Up @@ -1148,7 +1189,7 @@ def _visit(
#
foreach_task = (
DAGTask(foreach_template_name)
.dependencies([self._sanitize(node.name)])
.depends(f"{self._sanitize(node.name)}.Succeeded")
.template(foreach_template_name)
.arguments(
Arguments().parameters(
Expand Down Expand Up @@ -1193,6 +1234,15 @@ def _visit(
% self._sanitize(node.name)
)
)
# Add conditional if this is the first step in a conditional branch
if node.is_conditional and not any(
self.graph[in_func].is_conditional for in_func in node.in_funcs
):
in_func = node.in_funcs[0]
foreach_task.when(
"{{tasks.%s.outputs.parameters.switch-step}}==%s"
% (self._sanitize(in_func), node.name)
)
dag_tasks.append(foreach_task)
templates, dag_tasks_1 = _visit(
self.graph[node.out_funcs[0]],
Expand Down Expand Up @@ -1236,7 +1286,22 @@ def _visit(
self.graph[node.matching_join].in_funcs[0]
)
}
)
if not self.graph[
node.matching_join
].is_conditional_join
else
# Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
# ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
{
"expression": "get((%s)?.parameters, 'task-id')"
% " ?? ".join(
f"tasks['{self._sanitize(func)}']?.outputs"
for func in self.graph[
node.matching_join
].in_funcs
)
}
),
]
if not node.parallel_foreach
else [
Expand Down Expand Up @@ -1269,7 +1334,7 @@ def _visit(
join_foreach_task = (
DAGTask(self._sanitize(self.graph[node.matching_join].name))
.template(self._sanitize(self.graph[node.matching_join].name))
.dependencies([foreach_template_name])
.depends(f"{foreach_template_name}.Succeeded")
.arguments(
Arguments().parameters(
(
Expand Down Expand Up @@ -1378,6 +1443,7 @@ def _container_templates(self):

task_idx = ""
input_paths = ""
input_paths_filename = ""
root_input = None
# export input_paths as it is used multiple times in the container script
# and we do not want to repeat the values.
Expand Down Expand Up @@ -1568,6 +1634,11 @@ def _container_templates(self):
]
)
input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
elif node.is_conditional_join:
input_paths_filename = (
"$(python -m metaflow.plugins.argo.create_input_paths_file '%s' '%s')"
% (",".join(node.in_funcs), f"{self.flow.name}/{run_id}")
)
elif (
node.type == "join"
and self.graph[node.split_parents[-1]].type == "foreach"
Expand Down Expand Up @@ -1600,7 +1671,11 @@ def _container_templates(self):
"--task-id %s" % task_id,
"--retry-count %s" % retry_count,
"--max-user-code-retries %d" % user_code_retries,
"--input-paths %s" % input_paths,
(
"--input-paths-filename %s" % input_paths_filename
if input_paths_filename
else "--input-paths %s" % input_paths
),
]
if node.parallel_step:
step.append(
Expand Down Expand Up @@ -1814,7 +1889,7 @@ def _container_templates(self):
[Parameter("num-parallel"), Parameter("task-id-entropy")]
)
else:
# append this only for joins of foreaches, not static splits
# append these only for joins of foreaches, not static splits
inputs.append(Parameter("split-cardinality"))
# check if the node is a @parallel node.
elif node.parallel_step:
Expand Down Expand Up @@ -1849,6 +1924,13 @@ def _container_templates(self):
# are derived at runtime.
if not (node.name == "end" or node.parallel_step):
outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]

# If this step is a split-switch one, we need to output the switch step name
if node.type == "split-switch":
outputs.append(
Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
)

if node.type == "foreach":
# Emit split cardinality from foreach task
outputs.append(
Expand Down Expand Up @@ -3981,6 +4063,10 @@ def dependencies(self, dependencies):
self.payload["dependencies"] = dependencies
return self

# Fluent setter: stores `depends` under the "depends" key of this object's
# payload and returns self so calls can be chained.
# NOTE(review): Argo's `depends` is a boolean expression over task results
# (e.g. "a.Succeeded && (b.Succeeded || c.Succeeded)") and is documented as
# not combinable with the legacy `dependencies` field -- callers should set
# only one of the two; confirm against the Argo documentation.
def depends(self, depends: str):
self.payload["depends"] = depends
return self

def template(self, template):
# Template reference
self.payload["template"] = template
Expand All @@ -3992,6 +4078,10 @@ def inline(self, template):
self.payload["inline"] = template.to_json()
return self

# Fluent setter: stores `when` under the "when" key of this object's payload
# and returns self so calls can be chained.
# NOTE(review): presumably an Argo conditional expression (e.g.
# "{{tasks.a.outputs.parameters.switch-step}}==b") that decides whether the
# task runs -- confirm against the Argo `when` documentation.
def when(self, when: str):
self.payload["when"] = when
return self

# Fluent setter: stores `with_param` under the "withParam" key of this
# object's payload (camelCase key, unlike the snake_case argument name) and
# returns self so calls can be chained.
# NOTE(review): presumably maps to Argo's `withParam` loop field, which
# expects a JSON list or a template reference to one -- confirm.
def with_param(self, with_param):
self.payload["withParam"] = with_param
return self
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/argo/argo_workflows_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ def task_finished(
with open("/mnt/out/split_cardinality", "w") as file:
json.dump(flow._foreach_num_splits, file)

# For conditional branches we need to record the value of the switch to disk, in order to pass it as an
# output from the switching step to be used further down the DAG
if graph[step_name].type == "split-switch":
# TODO: A nicer way to access the chosen step?
_out_funcs, _ = flow._transition
chosen_step = _out_funcs[0]
with open("/mnt/out/switch_step", "w") as file:
file.write(chosen_step)

# For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
# to run the task. In this case, we cannot set anything in the
# `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
Expand Down
Loading