diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6e8a6465ff..06f1966f41 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) + [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py new file mode 100644 index 0000000000..408d3d785e --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -0,0 +1,71 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" + +from functools import singledispatchmethod + +from xdsl.dialects import builtin +from xdsl.ir import Block, Operation, Region + +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class ConstructCircuitDAG: + """Utility tool following the director pattern to build a DAG representation of a compiled quantum program. + + This tool traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) + of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module. + + **Example** + + >>> builder = PyDotDAGBuilder() + >>> director = ConstructCircuitDAG(builder) + >>> director.construct(module) + >>> director.dag_builder.to_string() + ... + """ + + def __init__(self, dag_builder: DAGBuilder) -> None: + self.dag_builder: DAGBuilder = dag_builder + + def construct(self, module: builtin.ModuleOp) -> None: + """Constructs the DAG from the module. + + Args: + module (xdsl.builtin.ModuleOp): The module containing the quantum program to visualize. + + """ + for op in module.ops: + self._visit_operation(op) + + # ============= + # IR TRAVERSAL + # ============= + + @singledispatchmethod + def _visit_operation(self, operation: Operation) -> None: + """Visit an xDSL Operation. Default to visiting each region contained in the operation.""" + for region in operation.regions: + self._visit_region(region) + + def _visit_region(self, region: Region) -> None: + """Visit an xDSL Region operation.""" + for block in region.blocks: + self._visit_block(block) + + def _visit_block(self, block: Block) -> None: + """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" + for op in block.ops: + self._visit_operation(op) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py new file mode 100644 index 0000000000..52ea0d27e6 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -0,0 +1,126 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ConstructCircuitDAG utility.""" + +from unittest.mock import Mock + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.visualization.construct_circuit_dag import ( + ConstructCircuitDAG, +) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class FakeDAGBuilder(DAGBuilder): + """ + A concrete implementation of DAGBuilder used ONLY for testing. + It stores all graph manipulation calls in data structures + for easy assertion of the final graph state. + """ + + def __init__(self): + self._nodes = {} + self._edges = [] + self._clusters = {} + + def add_node(self, id, label, cluster_id=None, **attrs) -> None: + self._nodes[id] = { + "id": id, + "label": label, + "parent_cluster_id": cluster_id, + "attrs": attrs, + } + + def add_edge(self, from_id: str, to_id: str, **attrs) -> None: + self._edges.append( + { + "from_id": from_id, + "to_id": to_id, + "attrs": attrs, + } + ) + + def add_cluster( + self, + id, + node_label=None, + cluster_id=None, + **attrs, + ) -> None: + self._clusters[id] = { + "id": id, + "node_label": node_label, + "cluster_label": attrs.get("label"), + "parent_cluster_id": cluster_id, + "attrs": attrs, + } + + @property + def nodes(self): + return self._nodes + + @property + def edges(self): + return self._edges + + @property + def clusters(self): + return self._clusters + + def to_file(self, output_filename): + pass + + def to_string(self) -> str: + return "graph" + + +@pytest.mark.unit +def test_dependency_injection(): + """Tests that relevant dependencies are injected.""" + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + assert utility.dag_builder is mock_dag_builder + + +@pytest.mark.unit +def test_does_not_mutate_module(): + """Test that the module is not mutated.""" + + # Create module + op = test.TestOp() + block = Block(ops=[op]) + region = Region(blocks=[block]) + container_op = test.TestOp(regions=[region]) + module_op = ModuleOp(ops=[container_op]) + + # Save state before + module_op_str_before = str(module_op) + + # Process module + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + utility.construct(module_op) + + # Ensure not mutated + assert str(module_op) == module_op_str_before