1
1
import types
2
+ from collections import defaultdict
2
3
3
4
import graphviz as gv
4
5
@@ -25,18 +26,64 @@ def get_workflow(name):
25
26
26
27
27
28
class NoDashDiGraph (gv .Digraph ):
28
- """Like `.graphviz.Digraph` but removes underscores from labels."""
29
+ """
30
+ Like `.graphviz.Digraph` but with unique nodes and edges.
31
+
32
+ Nodes and edges are unique and their attributes will be overridden
33
+ should the same node or edge be added twice. Nodes are unique by name
34
+ and edges unique by head and tail.
35
+
36
+ Underscores are replaced with whitespaces from identifiers.
37
+ """
29
38
30
39
def __init__ (self , * args , ** kwargs ):
31
- self ._edges = []
40
+ self ._nodes = defaultdict (dict )
41
+ self ._edges = defaultdict (dict )
32
42
super ().__init__ (* args , ** kwargs )
33
43
34
- def edge (self , tail_name , head_name , label = None , _attributes = None , ** attrs ):
35
- if not (tail_name , head_name ) in self ._edges :
36
- self ._edges .append ((tail_name , head_name ))
37
- super ().edge (
38
- tail_name , head_name , label = label , _attributes = _attributes , ** attrs
39
- )
44
+ def __iter__ (self , subgraph = False ):
45
+ """Yield the DOT source code line by line (as graph or subgraph)."""
46
+ if self .comment :
47
+ yield self ._comment % self .comment
48
+
49
+ if subgraph :
50
+ if self .strict :
51
+ raise ValueError ("subgraphs cannot be strict" )
52
+ head = self ._subgraph if self .name else self ._subgraph_plain
53
+ else :
54
+ head = self ._head_strict if self .strict else self ._head
55
+ yield head % (self ._quote (self .name ) + " " if self .name else "" )
56
+
57
+ for kw in ("graph" , "node" , "edge" ):
58
+ attrs = getattr (self , "%s_attr" % kw )
59
+ if attrs :
60
+ yield self ._attr % (kw , self ._attr_list (None , attrs ))
61
+
62
+ yield from self .body
63
+
64
+ for name , attrs in sorted (self ._nodes .items ()):
65
+ name = self ._quote (name )
66
+ label = attrs .pop ("label" , None )
67
+ _attributes = attrs .pop ("_attributes" , None )
68
+ attr_list = self ._attr_list (label , attrs , _attributes )
69
+ yield self ._node % (name , attr_list )
70
+
71
+ for edge , attrs in sorted (self ._edges .items ()):
72
+ head_name , tail_name = edge
73
+ tail_name = self ._quote_edge (tail_name )
74
+ head_name = self ._quote_edge (head_name )
75
+ label = attrs .pop ("label" , None )
76
+ _attributes = attrs .pop ("_attributes" , None )
77
+ attr_list = self ._attr_list (label , attrs , _attributes )
78
+ yield self ._edge % (head_name , tail_name , attr_list )
79
+
80
+ yield self ._tail
81
+
82
+ def node (self , name , ** attrs ):
83
+ self ._nodes [name ] = attrs
84
+
85
+ def edge (self , tail_name , head_name , ** attrs ):
86
+ self ._edges [(tail_name , head_name )] = attrs
40
87
41
88
@staticmethod
42
89
def _quote (identifier , * args , ** kwargs ):
0 commit comments