From a63412d5226f148a050a321ab3e73fbd964a1ed9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:31:55 -0500 Subject: [PATCH 1/4] basic cl --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b135975266..831453990c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,9 @@

New features since last release

+* Compiled programs can be visualized. + [(#)]() + * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From 82b023a066d1f65e3b9f37d43d9e1da793d20450 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:58:15 -0500 Subject: [PATCH 2/4] Trigger CI From 248b7bd57361e812e99b6cb2a4c005e4ce03996f Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:13:56 -0500 Subject: [PATCH 3/4] feat: add `DAGBuilder` abstract base class and `PyDotDAGBuilder` concrete subclass (#2213) Copied over what was in https://github.com/PennyLaneAI/pennylane/pull/8626. [sc-103457] [sc-103456] --------- Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- doc/releases/changelog-dev.md | 2 +- .../visualization/dag_builder.py | 100 ++++++ .../visualization/pydot_dag_builder.py | 207 ++++++++++++ .../visualization/test_dag_builder.py | 96 ++++++ .../visualization/test_pydot_dag_builder.py | 300 ++++++++++++++++++ requirements.txt | 1 + 6 files changed, 705 insertions(+), 1 deletion(-) create mode 100644 frontend/catalyst/python_interface/visualization/dag_builder.py create mode 100644 frontend/catalyst/python_interface/visualization/pydot_dag_builder.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_dag_builder.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index debb67c0a3..ee7c22c7e3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -3,7 +3,7 @@

New features since last release

* Compiled programs can be visualized. - [(#)]() + [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py new file mode 100644 index 0000000000..aeb8b217b6 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -0,0 +1,100 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the DAGBuilder abstract base class.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class DAGBuilder(ABC): + """An abstract base class for building Directed Acyclic Graphs (DAGs). + + This class provides a simple interface with three core methods (`add_node`, `add_edge` and `add_cluster`). + You can override these methods to implement any backend, like `pydot` or `graphviz` or even `matplotlib`. + + Outputting your graph can be done by overriding `to_file` and `to_string`. + """ + + @abstractmethod + def add_node( + self, node_id: str, node_label: str, parent_graph_id: str | None = None, **node_attrs: Any + ) -> None: + """Add a single node to the graph. + + Args: + node_id (str): Unique node ID to identify this node. + node_label (str): The text to display on the node when rendered. + parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + **node_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_node_id (str): The unique ID of the source node. + to_node_id (str): The unique ID of the destination node. + **edge_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_cluster( + self, + cluster_id: str, + node_label: str | None = None, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + cluster_id (str): Unique cluster ID to identify this cluster. + node_label (str | None): The text to display on an information node within the cluster when rendered. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + **cluster_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + The implementation should ideally infer the output format + (e.g., 'png', 'svg') from this filename's extension. + + Args: + output_filename (str): Desired filename for the graph. + + """ + raise NotImplementedError + + @abstractmethod + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py new file mode 100644 index 0000000000..384537c93e --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -0,0 +1,207 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the PyDotDAGBuilder subclass of DAGBuilder.""" + +import pathlib +from collections import ChainMap +from typing import Any + +from .dag_builder import DAGBuilder + +has_pydot = True +try: + import pydot +except ImportError: + has_pydot = False + + +class PyDotDAGBuilder(DAGBuilder): + """A Directed Acyclic Graph builder for the PyDot backend.""" + + def __init__( + self, + attrs: dict | None = None, + node_attrs: dict | None = None, + edge_attrs: dict | None = None, + cluster_attrs: dict | None = None, + ) -> None: + """Initialize PyDotDAGBuilder instance. + + Args: + attrs (dict | None): User default attributes to be used for all elements (nodes, edges, clusters) in the graph. + node_attrs (dict | None): User default attributes for a node. + edge_attrs (dict | None): User default attributes for an edge. + cluster_attrs (dict | None): User default attributes for a cluster. + + """ + # Initialize the pydot graph: + # - graph_type="digraph": Create a directed graph (edges have arrows). + # - rankdir="TB": Set layout direction from Top to Bottom. + # - compound="true": Allow edges to connect directly to clusters/subgraphs. + # - strict=True: Prevent duplicate edges (e.g., A -> B added twice). + self.graph: pydot.Dot = pydot.Dot( + graph_type="digraph", rankdir="TB", compound="true", strict=True + ) + # Create cache for easy look-up + self._subgraphs: dict[str, pydot.Graph] = {} + self._subgraphs["__base__"] = self.graph + + _default_attrs: dict = {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + self._default_node_attrs: dict = ( + { + **_default_attrs, + "shape": "ellipse", + "style": "filled", + "fillcolor": "lightblue", + "color": "lightblue4", + "penwidth": 3, + } + if node_attrs is None + else node_attrs + ) + self._default_edge_attrs: dict = ( + { + "color": "lightblue4", + "penwidth": 3, + } + if edge_attrs is None + else edge_attrs + ) + self._default_cluster_attrs: dict = ( + { + **_default_attrs, + "shape": "rectangle", + "style": "solid", + } + if cluster_attrs is None + else cluster_attrs + ) + + def add_node( + self, + node_id: str, + node_label: str, + parent_graph_id: str | None = None, + **node_attrs: Any, + ) -> None: + """Add a single node to the graph. + + Args: + node_id (str): Unique node ID to identify this node. + node_label (str): The text to display on the node when rendered. + parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + **node_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + node_attrs = ChainMap(node_attrs, self._default_node_attrs) + node = pydot.Node(node_id, label=node_label, **node_attrs) + parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + + self._subgraphs[parent_graph_id].add_node(node) + + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_node_id (str): The unique ID of the source node. + to_node_id (str): The unique ID of the destination node. + **edge_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + edge_attrs = ChainMap(edge_attrs, self._default_edge_attrs) + edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) + self.graph.add_edge(edge) + + def add_cluster( + self, + cluster_id: str, + node_label: str | None = None, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + cluster_id (str): Unique cluster ID to identify this cluster. + node_label (str | None): The text to display on the information node within the cluster when rendered. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + **cluster_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + cluster_attrs = ChainMap(cluster_attrs, self._default_cluster_attrs) + cluster = pydot.Cluster(graph_name=cluster_id, **cluster_attrs) + + # Puts the label in a node within the cluster. + # Ensures that any edges connecting nodes through the cluster + # boundary don't block the label. + # ┌───────────┐ + # │ ┌───────┐ │ + # │ │ label │ │ + # │ └───────┘ │ + # │ │ + # └───────────┘ + if node_label: + node_id = f"{cluster_id}_info_node" + rank_subgraph = pydot.Subgraph() + node = pydot.Node( + node_id, + label=node_label, + shape="rectangle", + style="dashed", + fontname="Helvetica", + penwidth=2, + ) + rank_subgraph.add_node(node) + cluster.add_subgraph(rank_subgraph) + cluster.add_node(node) + + self._subgraphs[cluster_id] = cluster + + parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + self._subgraphs[parent_graph_id].add_subgraph(cluster) + + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + This method will infer the file's format (e.g., 'png', 'svg') from this filename's extension. + If no extension is provided, the 'png' format will be the default. + + Args: + output_filename (str): Desired filename for the graph. File extension can be included + and if no file extension is provided, it will default to a `.png` file. + + """ + output_filename_path: pathlib.Path = pathlib.Path(output_filename) + if not output_filename_path.suffix: + output_filename_path = output_filename_path.with_suffix(".png") + + format = output_filename_path.suffix[1:].lower() + + self.graph.write(str(output_filename_path), format=format) + + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + return self.graph.to_string() diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py new file mode 100644 index 0000000000..2e935bae1b --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -0,0 +1,96 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the DAGBuilder abstract base class.""" + +from typing import Any + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +def test_concrete_implementation_works(): + """Unit test for concrete implementation of abc.""" + + # pylint: disable=unused-argument + class ConcreteDAGBuilder(DAGBuilder): + """Concrete subclass of an ABC for testing purposes.""" + + def add_node( + self, + node_id: str, + node_label: str, + parent_graph_id: str | None = None, + **node_attrs: Any, + ) -> None: + return + + def add_edge( + self, from_node_id: str, to_node_id: str, **edge_attrs: Any + ) -> None: + return + + def add_cluster( + self, + cluster_id: str, + node_label: str | None = None, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + return + + def to_file(self, output_filename: str) -> None: + return + + def to_string(self) -> str: + return "test" + + dag_builder = ConcreteDAGBuilder() + # pylint: disable = assignment-from-none + node = dag_builder.add_node("0", "node0") + edge = dag_builder.add_edge("0", "1") + cluster = dag_builder.add_cluster("0") + render = dag_builder.to_file("test.png") + string = dag_builder.to_string() + + assert node is None + assert edge is None + assert cluster is None + assert render is None + assert string == "test" + + +def test_abc_cannot_be_instantiated(): + """Tests that the DAGBuilder ABC cannot be instantiated.""" + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + DAGBuilder() + + +def test_incomplete_subclass(): + """Tests that an incomplete subclass will fail""" + + # pylint: disable=too-few-public-methods + class IncompleteDAGBuilder(DAGBuilder): + def add_node(self, *args, **kwargs): + pass + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + IncompleteDAGBuilder() diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py new file mode 100644 index 0000000000..11e5e4ac47 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -0,0 +1,300 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the PyDotDAGBuilder subclass.""" + +from unittest.mock import MagicMock + +import pytest + +pydot = pytest.importorskip("pydot") +pytestmark = pytest.mark.usefixtures("requires_xdsl") +# pylint: disable=wrong-import-position +from catalyst.python_interface.visualization.pydot_dag_builder import PyDotDAGBuilder + + +@pytest.mark.unit +def test_initialization_defaults(): + """Tests the default graph attributes are as expected.""" + + dag_builder = PyDotDAGBuilder() + + assert isinstance(dag_builder.graph, pydot.Dot) + # Ensure it's a directed graph + assert dag_builder.graph.get_graph_type() == "digraph" + # Ensure that it flows top to bottom + assert dag_builder.graph.get_rankdir() == "TB" + # Ensure edges can be connected directly to clusters / subgraphs + assert dag_builder.graph.get_compound() == "true" + # Ensure duplicated edges cannot be added + assert dag_builder.graph.obj_dict["strict"] is True + + +class TestAddMethods: + """Test that elements can be added to the graph.""" + + @pytest.mark.unit + def test_add_node(self): + """Unit test the `add_node` method.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + node_list = dag_builder.graph.get_node_list() + assert len(node_list) == 1 + assert node_list[0].get_label() == "node0" + + @pytest.mark.unit + def test_add_edge(self): + """Unit test the `add_edge` method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + + assert len(dag_builder.graph.get_edges()) == 1 + edge = dag_builder.graph.get_edges()[0] + assert edge.get_source() == "0" + assert edge.get_destination() == "1" + + @pytest.mark.unit + def test_add_cluster(self): + """Unit test the 'add_cluster' method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0") + + assert len(dag_builder.graph.get_subgraphs()) == 1 + assert dag_builder.graph.get_subgraphs()[0].get_name() == "cluster_0" + + @pytest.mark.unit + def test_add_node_to_parent_graph(self): + """Tests that you can add a node to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Create node + dag_builder.add_node("0", "node0") + + # Create cluster + dag_builder.add_cluster("c0") + + # Create node inside cluster + dag_builder.add_node("1", "node1", parent_graph_id="c0") + + # Verify graph structure + root_graph = dag_builder.graph + + # Make sure the base graph has node0 + assert root_graph.get_node("0"), "Node 0 not found in root graph" + + # Get the cluster and verify it has node1 and not node0 + cluster_list = root_graph.get_subgraph("cluster_c0") + assert cluster_list, "Subgraph 'cluster_c0' not found" + cluster_graph = cluster_list[0] # Get the actual subgraph object + + assert cluster_graph.get_node("1"), "Node 1 not found in cluster 'c0'" + assert not cluster_graph.get_node("0"), ( + "Node 0 was incorrectly added to cluster" + ) + + assert not root_graph.get_node("1"), "Node 1 was incorrectly added to root" + + @pytest.mark.unit + def test_add_cluster_to_parent_graph(self): + """Test that you can add a cluster to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Level 0 (Root): Adds cluster on top of base graph + dag_builder.add_node("n_root", "node_root") + dag_builder.add_cluster("c0") + + # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top + dag_builder.add_node("n_outer", "node_outer", parent_graph_id="c0") + dag_builder.add_cluster("c1", parent_graph_id="c0") + + # Level 2 (Inside c1): Add node on second cluster + dag_builder.add_node("n_inner", "node_inner", parent_graph_id="c1") + + root_graph = dag_builder.graph + + outer_cluster_list = root_graph.get_subgraph("cluster_c0") + assert outer_cluster_list, "Outer cluster 'c0' not found in root" + c0 = outer_cluster_list[0] + + inner_cluster_list = c0.get_subgraph("cluster_c1") + assert inner_cluster_list, "Inner cluster 'c1' not found in 'c0'" + c1 = inner_cluster_list[0] + + # Check Level 0 (Root) + assert root_graph.get_node("n_root"), "n_root not found in root" + assert root_graph.get_subgraph("cluster_c0"), "c0 not found in root" + assert not root_graph.get_node("n_outer"), "n_outer incorrectly found in root" + assert not root_graph.get_node("n_inner"), "n_inner incorrectly found in root" + assert not root_graph.get_subgraph("cluster_c1"), "c1 incorrectly found in root" + + # Check Level 1 (c0) + assert c0.get_node("n_outer"), "n_outer not found in c0" + assert c0.get_subgraph("cluster_c1"), "c1 not found in c0" + assert not c0.get_node("n_root"), "n_root incorrectly found in c0" + assert not c0.get_node("n_inner"), "n_inner incorrectly found in c0" + + # Check Level 2 (c1) + assert c1.get_node("n_inner"), "n_inner not found in c1" + assert not c1.get_node("n_root"), "n_root incorrectly found in c1" + assert not c1.get_node("n_outer"), "n_outer incorrectly found in c1" + + +class TestAttributes: + """Tests that the attributes for elements in the graph are overridden correctly.""" + + @pytest.mark.unit + def test_default_graph_attrs(self): + """Test that default graph attributes can be set.""" + + dag_builder = PyDotDAGBuilder(attrs={"fontname": "Times"}) + + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fontname") == "Times" + + dag_builder.add_cluster("1") + cluster = dag_builder.graph.get_subgraphs()[0] + assert cluster.get("fontname") == "Times" + + @pytest.mark.unit + def test_add_node_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder( + node_attrs={"fillcolor": "lightblue", "penwidth": 3} + ) + + # Defaults + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fillcolor") == "lightblue" + assert node0.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_node("1", "node1", fillcolor="red", penwidth=4) + node1 = dag_builder.graph.get_node("1")[0] + assert node1.get("fillcolor") == "red" + assert node1.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_edge_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder(edge_attrs={"color": "lightblue4", "penwidth": 3}) + + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + edge = dag_builder.graph.get_edges()[0] + # Defaults defined earlier + assert edge.get("color") == "lightblue4" + assert edge.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_edge("0", "1", color="red", penwidth=4) + edge = dag_builder.graph.get_edges()[1] + assert edge.get("color") == "red" + assert edge.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_cluster_with_attrs(self): + """Tests that default cluster attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder( + cluster_attrs={ + "style": "solid", + "fillcolor": None, + "penwidth": 2, + "fontname": "Helvetica", + } + ) + + dag_builder.add_cluster("0") + cluster1 = dag_builder.graph.get_subgraph("cluster_0")[0] + + # Defaults + assert cluster1.get("style") == "solid" + assert cluster1.get("fillcolor") is None + assert cluster1.get("penwidth") == 2 + assert cluster1.get("fontname") == "Helvetica" + + dag_builder.add_cluster( + "1", style="filled", penwidth=10, fillcolor="red" + ) + cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] + + # Make sure we can override + assert cluster2.get("style") == "filled" + assert cluster2.get("penwidth") == 10 + assert cluster2.get("fillcolor") == "red" + + # Check that other defaults are still present + assert cluster2.get("fontname") == "Helvetica" + + +class TestOutput: + """Test that the graph can be outputted correctly.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "filename, format", + [("my_graph", None), ("my_graph", "png"), ("prototype.trial1", "png")], + ) + def test_to_file(self, monkeypatch, filename, format): + """Tests that the `to_file` method works correctly.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(filename + "." + (format or "png")) + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with( + filename + "." + (format or "png"), format=format or "png" + ) + + @pytest.mark.unit + @pytest.mark.parametrize("format", ["pdf", "svg", "jpeg"]) + def test_other_supported_formats(self, monkeypatch, format): + """Tests that the `to_file` method works with other formats.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(f"my_graph.{format}") + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with(f"my_graph.{format}", format=format) + + @pytest.mark.unit + def test_to_string(self): + """Tests that the `to_string` method works correclty.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("n0", "node0") + dag_builder.add_node("n1", "node1") + dag_builder.add_edge("n0", "n1") + + string = dag_builder.to_string() + assert isinstance(string, str) + + # make sure important things show up in the string + assert "digraph" in string + assert "n0" in string + assert "n1" in string + assert "n0 -> n1" in string diff --git a/requirements.txt b/requirements.txt index 08f57ee9c0..6b904318c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,4 @@ pennylane-lightning-kokkos amazon-braket-pennylane-plugin>1.27.1 xdsl xdsl-jax +pydot From 09be3af01001e67d94167038b933f1980a517eec Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:56:11 -0500 Subject: [PATCH 4/4] feat: add `nodes`, `clusters` and `edges` properties to DAGBuilders (#2229) Main goal of this PR was to add properties to help probe the inner structure of the DAG. This PR also does, - minor clean-up w.r.t naming and formatting - adds validation to add_* methods [sc-103457] [sc-103456] --------- Co-authored-by: Mudit Pandey <18223836+mudit2812@users.noreply.github.com> Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- doc/releases/changelog-dev.md | 1 + .../visualization/dag_builder.py | 71 ++++++-- .../visualization/pydot_dag_builder.py | 165 +++++++++++++----- .../visualization/test_dag_builder.py | 36 ++-- .../visualization/test_pydot_dag_builder.py | 142 +++++++++++++-- 5 files changed, 335 insertions(+), 80 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ee7c22c7e3..6e8a6465ff 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) + [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index aeb8b217b6..70c5806616 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -14,7 +14,10 @@ """File that defines the DAGBuilder abstract base class.""" from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeAlias + +ClusterID: TypeAlias = str +NodeID: TypeAlias = str class DAGBuilder(ABC): @@ -28,27 +31,32 @@ class DAGBuilder(ABC): @abstractmethod def add_node( - self, node_id: str, node_label: str, parent_graph_id: str | None = None, **node_attrs: Any + self, + id: NodeID, + label: str, + cluster_id: ClusterID | None = None, + **attrs: Any, ) -> None: """Add a single node to the graph. Args: - node_id (str): Unique node ID to identify this node. - node_label (str): The text to display on the node when rendered. - parent_graph_id (str | None): Optional ID of the cluster this node belongs to. - **node_attrs (Any): Any additional styling keyword arguments. + id (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_id (str | None): Optional ID of the cluster this node belongs to. If `None`, this node gets + added on the base graph. + **attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError @abstractmethod - def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + def add_edge(self, from_id: NodeID, to_id: NodeID, **attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: - from_node_id (str): The unique ID of the source node. - to_node_id (str): The unique ID of the destination node. - **edge_attrs (Any): Any additional styling keyword arguments. + from_id (str): The unique ID of the source node. + to_id (str): The unique ID of the destination node. + **attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError @@ -56,10 +64,10 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non @abstractmethod def add_cluster( self, - cluster_id: str, + id: ClusterID, node_label: str | None = None, - parent_graph_id: str | None = None, - **cluster_attrs: Any, + cluster_id: ClusterID | None = None, + **attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -67,11 +75,42 @@ def add_cluster( within it are visually and logically grouped. Args: - cluster_id (str): Unique cluster ID to identify this cluster. + id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on an information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. - **cluster_attrs (Any): Any additional styling keyword arguments. + cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be + placed on the base graph. + **attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @property + @abstractmethod + def nodes(self) -> dict[NodeID, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. + """ + raise NotImplementedError + + @property + @abstractmethod + def edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. + """ + raise NotImplementedError + @property + @abstractmethod + def clusters(self) -> dict[ClusterID, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. """ raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 384537c93e..ea01ecdf21 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -22,6 +22,7 @@ has_pydot = True try: import pydot + from pydot import Cluster, Dot, Edge, Graph, Node, Subgraph except ImportError: has_pydot = False @@ -50,14 +51,23 @@ def __init__( # - rankdir="TB": Set layout direction from Top to Bottom. # - compound="true": Allow edges to connect directly to clusters/subgraphs. # - strict=True: Prevent duplicate edges (e.g., A -> B added twice). - self.graph: pydot.Dot = pydot.Dot( + self.graph: Dot = Dot( graph_type="digraph", rankdir="TB", compound="true", strict=True ) - # Create cache for easy look-up - self._subgraphs: dict[str, pydot.Graph] = {} - self._subgraphs["__base__"] = self.graph - _default_attrs: dict = {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + # Use internal cache that maps cluster ID to actual pydot (Dot or Cluster) object + # NOTE: This is needed so we don't need to traverse the graph to find the relevant + # cluster object to modify + self._subgraph_cache: dict[str, Graph] = {} + + # Internal state for graph structure + self._nodes: dict[str, dict[str, Any]] = {} + self._edges: list[dict[str, Any]] = [] + self._clusters: dict[str, dict[str, Any]] = {} + + _default_attrs: dict = ( + {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + ) self._default_node_attrs: dict = ( { **_default_attrs, @@ -90,47 +100,80 @@ def __init__( def add_node( self, - node_id: str, - node_label: str, - parent_graph_id: str | None = None, - **node_attrs: Any, + id: str, + label: str, + cluster_id: str | None = None, + **attrs: Any, ) -> None: """Add a single node to the graph. Args: - node_id (str): Unique node ID to identify this node. - node_label (str): The text to display on the node when rendered. - parent_graph_id (str | None): Optional ID of the cluster this node belongs to. - **node_attrs (Any): Any additional styling keyword arguments. + id (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_id (str | None): Optional ID of the cluster this node belongs to. + **attrs (Any): Any additional styling keyword arguments. - """ - # Use ChainMap so you don't need to construct a new dictionary - node_attrs = ChainMap(node_attrs, self._default_node_attrs) - node = pydot.Node(node_id, label=node_label, **node_attrs) - parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + Raises: + ValueError: Node ID is already present in the graph. - self._subgraphs[parent_graph_id].add_node(node) + """ + if id in self.nodes: + raise ValueError(f"Node ID {id} already present in graph.") - def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + # Use ChainMap so you don't need to construct a new dictionary + node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) + node = Node(id, label=label, **node_attrs) + + # Add node to cluster + if cluster_id is None: + self.graph.add_node(node) + else: + self._subgraph_cache[cluster_id].add_node(node) + + self._nodes[id] = { + "id": id, + "label": label, + "cluster_id": cluster_id, + "attrs": dict(node_attrs), + } + + def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: - from_node_id (str): The unique ID of the source node. - to_node_id (str): The unique ID of the destination node. - **edge_attrs (Any): Any additional styling keyword arguments. + from_id (str): The unique ID of the source node. + to_id (str): The unique ID of the destination node. + **attrs (Any): Any additional styling keyword arguments. + + Raises: + ValueError: Source and destination have the same ID + ValueError: Source is not found in the graph. + ValueError: Destination is not found in the graph. """ + if from_id == to_id: + raise ValueError("Edges must connect two unique IDs.") + if from_id not in self.nodes: + raise ValueError("Source is not found in the graph.") + if to_id not in self.nodes: + raise ValueError("Destination is not found in the graph.") + # Use ChainMap so you don't need to construct a new dictionary - edge_attrs = ChainMap(edge_attrs, self._default_edge_attrs) - edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) + edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs) + edge = Edge(from_id, to_id, **edge_attrs) + self.graph.add_edge(edge) + self._edges.append( + {"from_id": from_id, "to_id": to_id, "attrs": dict(edge_attrs)} + ) + def add_cluster( self, - cluster_id: str, + id: str, node_label: str | None = None, - parent_graph_id: str | None = None, - **cluster_attrs: Any, + cluster_id: str | None = None, + **attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -138,15 +181,20 @@ def add_cluster( within it are visually and logically grouped. Args: - cluster_id (str): Unique cluster ID to identify this cluster. + id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on the information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. - **cluster_attrs (Any): Any additional styling keyword arguments. + cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. + **attrs (Any): Any additional styling keyword arguments. + Raises: + ValueError: Cluster ID is already present in the graph. """ + if id in self.clusters: + raise ValueError(f"Cluster ID {id} already present in graph.") + # Use ChainMap so you don't need to construct a new dictionary - cluster_attrs = ChainMap(cluster_attrs, self._default_cluster_attrs) - cluster = pydot.Cluster(graph_name=cluster_id, **cluster_attrs) + cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) + cluster = Cluster(id, **cluster_attrs) # Puts the label in a node within the cluster. # Ensures that any edges connecting nodes through the cluster @@ -158,9 +206,9 @@ def add_cluster( # │ │ # └───────────┘ if node_label: - node_id = f"{cluster_id}_info_node" - rank_subgraph = pydot.Subgraph() - node = pydot.Node( + node_id = f"{id}_info_node" + rank_subgraph = Subgraph() + node = Node( node_id, label=node_label, shape="rectangle", @@ -172,10 +220,49 @@ def add_cluster( cluster.add_subgraph(rank_subgraph) cluster.add_node(node) - self._subgraphs[cluster_id] = cluster + # Record new cluster + self._subgraph_cache[id] = cluster - parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id - self._subgraphs[parent_graph_id].add_subgraph(cluster) + # Add node to cluster + if cluster_id is None: + self.graph.add_subgraph(cluster) + else: + self._subgraph_cache[cluster_id].add_subgraph(cluster) + + self._clusters[id] = { + "id": id, + "cluster_label": cluster_attrs.get("label"), + "node_label": node_label, + "cluster_id": cluster_id, + "attrs": dict(cluster_attrs), + } + + @property + def nodes(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. + """ + return self._nodes + + @property + def edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. + """ + return self._edges + + @property + def clusters(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. + """ + return self._clusters def to_file(self, output_filename: str) -> None: """Save the graph to a file. diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 2e935bae1b..df3a431ae2 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -33,27 +33,37 @@ class ConcreteDAGBuilder(DAGBuilder): def add_node( self, - node_id: str, - node_label: str, - parent_graph_id: str | None = None, - **node_attrs: Any, + id: str, + label: str, + cluster_id: str | None = None, + **attrs: Any, ) -> None: return - def add_edge( - self, from_node_id: str, to_node_id: str, **edge_attrs: Any - ) -> None: + def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: return def add_cluster( self, - cluster_id: str, + id: str, node_label: str | None = None, - parent_graph_id: str | None = None, - **cluster_attrs: Any, + cluster_id: str | None = None, + **attrs: Any, ) -> None: return + @property + def nodes(self) -> dict[str, dict[str, Any]]: + return {} + + @property + def edges(self) -> list[dict[str, Any]]: + return [] + + @property + def clusters(self) -> dict[str, dict[str, Any]]: + return {} + def to_file(self, output_filename: str) -> None: return @@ -65,12 +75,18 @@ def to_string(self) -> str: node = dag_builder.add_node("0", "node0") edge = dag_builder.add_edge("0", "1") cluster = dag_builder.add_cluster("0") + nodes = dag_builder.nodes + edges = dag_builder.edges + clusters = dag_builder.clusters render = dag_builder.to_file("test.png") string = dag_builder.to_string() assert node is None + assert nodes == {} assert edge is None + assert edges == [] assert cluster is None + assert clusters == {} assert render is None assert string == "test" diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 11e5e4ac47..d9a17f5e3b 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -40,6 +40,53 @@ def test_initialization_defaults(): assert dag_builder.graph.obj_dict["strict"] is True +class TestExceptions: + """Tests the various exceptions defined in the class.""" + + def test_duplicate_node_ids(self): + """Tests that a ValueError is raised for duplicate nodes.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Node ID 0 already present in graph."): + dag_builder.add_node("0", "node1") + + def test_edge_duplicate_source_destination(self): + """Tests that a ValueError is raised when an edge is created with the + same source and destination""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Edges must connect two unique IDs."): + dag_builder.add_edge("0", "0") + + def test_edge_missing_ids(self): + """Tests that an error is raised if IDs are missing.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Destination is not found in the graph."): + dag_builder.add_edge("0", "1") + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("1", "node1") + with pytest.raises(ValueError, match="Source is not found in the graph."): + dag_builder.add_edge("0", "1") + + def test_duplicate_cluster_id(self): + """Tests that an exception is raised if an ID is already present.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_cluster("0") + with pytest.raises(ValueError, match="Cluster ID 0 already present in graph."): + dag_builder.add_cluster("0") + + class TestAddMethods: """Test that elements can be added to the graph.""" @@ -90,7 +137,7 @@ def test_add_node_to_parent_graph(self): dag_builder.add_cluster("c0") # Create node inside cluster - dag_builder.add_node("1", "node1", parent_graph_id="c0") + dag_builder.add_node("1", "node1", cluster_id="c0") # Verify graph structure root_graph = dag_builder.graph @@ -117,14 +164,14 @@ def test_add_cluster_to_parent_graph(self): # Level 0 (Root): Adds cluster on top of base graph dag_builder.add_node("n_root", "node_root") - dag_builder.add_cluster("c0") - # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top - dag_builder.add_node("n_outer", "node_outer", parent_graph_id="c0") - dag_builder.add_cluster("c1", parent_graph_id="c0") + # Level 1 (c0): Add node on outer cluster + dag_builder.add_cluster("c0") + dag_builder.add_node("n_outer", "node_outer", cluster_id="c0") - # Level 2 (Inside c1): Add node on second cluster - dag_builder.add_node("n_inner", "node_inner", parent_graph_id="c1") + # Level 2 (c1): Add node on inner cluster + dag_builder.add_cluster("c1", cluster_id="c0") + dag_builder.add_node("n_inner", "node_inner", cluster_id="c1") root_graph = dag_builder.graph @@ -175,9 +222,7 @@ def test_default_graph_attrs(self): @pytest.mark.unit def test_add_node_with_attrs(self): """Tests that default attributes are applied and can be overridden.""" - dag_builder = PyDotDAGBuilder( - node_attrs={"fillcolor": "lightblue", "penwidth": 3} - ) + dag_builder = PyDotDAGBuilder(attrs={"fillcolor": "lightblue", "penwidth": 3}) # Defaults dag_builder.add_node("0", "node0") @@ -194,7 +239,7 @@ def test_add_node_with_attrs(self): @pytest.mark.unit def test_add_edge_with_attrs(self): """Tests that default attributes are applied and can be overridden.""" - dag_builder = PyDotDAGBuilder(edge_attrs={"color": "lightblue4", "penwidth": 3}) + dag_builder = PyDotDAGBuilder(attrs={"color": "lightblue4", "penwidth": 3}) dag_builder.add_node("0", "node0") dag_builder.add_node("1", "node1") @@ -214,7 +259,7 @@ def test_add_edge_with_attrs(self): def test_add_cluster_with_attrs(self): """Tests that default cluster attributes are applied and can be overridden.""" dag_builder = PyDotDAGBuilder( - cluster_attrs={ + attrs={ "style": "solid", "fillcolor": None, "penwidth": 2, @@ -231,9 +276,7 @@ def test_add_cluster_with_attrs(self): assert cluster1.get("penwidth") == 2 assert cluster1.get("fontname") == "Helvetica" - dag_builder.add_cluster( - "1", style="filled", penwidth=10, fillcolor="red" - ) + dag_builder.add_cluster("1", style="filled", penwidth=10, fillcolor="red") cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] # Make sure we can override @@ -245,6 +288,75 @@ def test_add_cluster_with_attrs(self): assert cluster2.get("fontname") == "Helvetica" +class TestProperties: + """Tests the properties.""" + + def test_nodes(self): + """Tests that nodes works.""" + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0", fillcolor="red") + dag_builder.add_cluster("c0") + dag_builder.add_node("1", "node1", cluster_id="c0") + + nodes = dag_builder.nodes + + assert len(nodes) == 2 + assert len(nodes["0"]) == 4 + + assert nodes["0"]["id"] == "0" + assert nodes["0"]["label"] == "node0" + assert nodes["0"]["cluster_id"] == None + assert nodes["0"]["attrs"]["fillcolor"] == "red" + + assert nodes["1"]["id"] == "1" + assert nodes["1"]["label"] == "node1" + assert nodes["1"]["cluster_id"] == "c0" + + def test_edges(self): + """Tests that edges works.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1", penwidth=10) + + edges = dag_builder.edges + + assert len(edges) == 1 + + assert edges[0]["from_id"] == "0" + assert edges[0]["to_id"] == "1" + assert edges[0]["attrs"]["penwidth"] == 10 + + def test_clusters(self): + """Tests that clusters property works.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0", "my_info_node", label="my_cluster", penwidth=10) + + clusters = dag_builder.clusters + + dag_builder.add_cluster( + "1", "my_other_info_node", cluster_id="0", label="my_nested_cluster" + ) + clusters = dag_builder.clusters + assert len(clusters) == 2 + + assert len(clusters["0"]) == 5 + assert clusters["0"]["id"] == "0" + assert clusters["0"]["cluster_label"] == "my_cluster" + assert clusters["0"]["node_label"] == "my_info_node" + assert clusters["0"]["cluster_id"] == None + assert clusters["0"]["attrs"]["penwidth"] == 10 + + assert len(clusters["1"]) == 5 + assert clusters["1"]["id"] == "1" + assert clusters["1"]["cluster_label"] == "my_nested_cluster" + assert clusters["1"]["node_label"] == "my_other_info_node" + assert clusters["1"]["cluster_id"] == "0" + + class TestOutput: """Test that the graph can be outputted correctly."""