Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/unroll_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=invalid-name
# pylint: disable=invalid-name, cyclic-import

"""
Script demonstrating how to unroll a QASM 3 program using pyqasm.
Expand Down
77 changes: 77 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import functools
import re
from abc import ABC, abstractmethod
from collections import Counter
from copy import deepcopy
Expand Down Expand Up @@ -761,3 +762,79 @@ def accept(self, visitor):
Args:
visitor (QasmVisitor): The visitor to accept
"""

@abstractmethod
def merge(
self,
other: "QasmModule",
device_qubits: Optional[int] = None,
) -> "QasmModule":
"""Merge this module with another module.

Implemented by concrete subclasses to avoid version mixing and
import-time cycles. Implementations should ensure both operands
are normalized to the same version prior to merging.
"""

@staticmethod
def offset_statement_qubits(
stmt: qasm3_ast.Statement, offset: int
): # pylint: disable=too-many-branches
"""Offset qubit indices for a given statement in-place by ``offset``.
Handles gates, measurements, resets, and barriers (including slice forms).
"""
if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement):
bit = stmt.measure.qubit
if isinstance(bit, qasm3_ast.IndexedIdentifier):
for group in bit.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumGate):
for q in stmt.qubits:
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumReset):
q = stmt.qubits
if isinstance(q, qasm3_ast.IndexedIdentifier):
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumBarrier):
qubits = stmt.qubits
if len(qubits) == 0:
return
first = qubits[0]
if isinstance(first, qasm3_ast.IndexedIdentifier):
for group in first.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
elif isinstance(first, qasm3_ast.Identifier):
# Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E]
name = first.name
if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"):
slice_str = name[len("__PYQASM_QUBITS__") :]
# Parse slice forms [S:E], [:E], or [S:]
m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str)
if m:
start_s, end_s = m.group(1), m.group(2)
if start_s is None and end_s is not None:
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[:{end_v}]"
elif start_s is not None and end_s is None:
start_v = int(start_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:]"
elif start_s is not None and end_s is not None:
start_v = int(start_s) + offset
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]"

def offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int):
"""Backward-compat wrapper to the class staticmethod."""
return QasmModule.offset_statement_qubits(stmt, offset)
74 changes: 74 additions & 0 deletions src/pyqasm/modules/qasm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,77 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self.unrolled_ast.statements = final_stmt_list

def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule:
"""Merge two modules and return a QASM2 result without mixing versions.
- If ``other`` is QASM3, it is merged into this module's semantics, and
any standard gate includes are mapped to ``qelib1.inc``.
- The merged program keeps version "2.0" and prints as QASM2.
"""
if not isinstance(other, QasmModule):
raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}")

left_mod = self.copy()
right_mod = other.copy()

# Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__
unroll_kwargs: dict[str, object] = {"consolidate_qubits": True}
if device_qubits is not None:
unroll_kwargs["device_qubits"] = device_qubits

left_mod.unroll(**unroll_kwargs)
right_mod.unroll(**unroll_kwargs)

left_qubits = left_mod.num_qubits

merged_program = Program(statements=[], version="2.0")

# Unique includes first; map stdgates.inc -> qelib1.inc for QASM2
include_names: set[str] = set()
for module in (left_mod, right_mod):
for stmt in module.unrolled_ast.statements:
if isinstance(stmt, Include):
fname = stmt.filename
if fname == "stdgates.inc":
fname = "qelib1.inc"
include_names.add(fname)
for fname in include_names:
merged_program.statements.append(Include(filename=fname))

# Consolidated qubit declaration (converted to qreg on print)
merged_program.statements.append(
qasm3_ast.QubitDeclaration(
size=qasm3_ast.IntegerLiteral(value=left_qubits + right_mod.num_qubits),
qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"),
)
)

# Append left ops (skip decls and includes)
for stmt in left_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)):
continue
merged_program.statements.append(deepcopy(stmt))

# Append right ops with index offset
for stmt in right_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)):
continue
stmt = deepcopy(stmt)
QasmModule.offset_statement_qubits(stmt, left_qubits)
merged_program.statements.append(stmt)

merged_module = Qasm2Module(
name=f"{left_mod.name}_merged_{right_mod.name}",
program=merged_program,
)
merged_module.unrolled_ast = Program(
statements=list(merged_program.statements),
version="2.0",
)
merged_module._external_gates = list(
{*left_mod._external_gates, *right_mod._external_gates}
)
merged_module._user_operations = list(left_mod.history) + list(right_mod.history)
merged_module._user_operations.append(f"merge(other={right_mod.name})")
merged_module.validate()
return merged_module
76 changes: 76 additions & 0 deletions src/pyqasm/modules/qasm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Defines a module for handling OpenQASM 3.0 programs.
"""

import openqasm3.ast as qasm3_ast
from openqasm3.ast import Program
from openqasm3.printer import dumps

Expand Down Expand Up @@ -52,3 +53,78 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self._unrolled_ast.statements = final_stmt_list

def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule:
"""Merge two modules as OpenQASM 3.0 without mixing versions.
If ``other`` is QASM2, it will be converted to QASM3 before merging.
The merged program keeps version "3.0".
"""
if not isinstance(other, QasmModule):
raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}")

# Convert right to QASM3 if it supports conversion; otherwise copy
right_mod = (
other.to_qasm3() if hasattr(other, "to_qasm3") else other.copy()
) # type: ignore[assignment]

left_mod = self.copy()

# Unroll with consolidation so both use __PYQASM_QUBITS__
unroll_kwargs: dict[str, object] = {"consolidate_qubits": True}
if device_qubits is not None:
unroll_kwargs["device_qubits"] = device_qubits

left_mod.unroll(**unroll_kwargs)
right_mod.unroll(**unroll_kwargs)

left_qubits = left_mod.num_qubits
total_qubits = left_qubits + right_mod.num_qubits

merged_program = Program(statements=[], version="3.0")

# Unique includes first
include_names: set[str] = set()
for module in (left_mod, right_mod):
for stmt in module.unrolled_ast.statements:
if isinstance(stmt, qasm3_ast.Include):
include_names.add(stmt.filename)
for fname in include_names:
merged_program.statements.append(qasm3_ast.Include(filename=fname))

# Consolidated qubit declaration
merged_program.statements.append(
qasm3_ast.QubitDeclaration(
size=qasm3_ast.IntegerLiteral(value=total_qubits),
qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"),
)
)

# Append left ops
for stmt in left_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
merged_program.statements.append(stmt)

# Append right ops with index offset
for stmt in right_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
# right_mod is a copy, so it's safe to modify statements in place
QasmModule.offset_statement_qubits(stmt, left_qubits)
merged_program.statements.append(stmt)

merged_module = Qasm3Module(
name=f"{left_mod.name}_merged_{right_mod.name}",
program=merged_program,
)
merged_module.unrolled_ast = Program(
statements=list(merged_program.statements),
version="3.0",
)
merged_module._external_gates = list(
{*left_mod._external_gates, *right_mod._external_gates}
)
merged_module._user_operations = list(left_mod.history) + list(right_mod.history)
merged_module._user_operations.append(f"merge(other={right_mod.name})")
merged_module.validate()
return merged_module
126 changes: 126 additions & 0 deletions tests/qasm3/test_merge.py
Copy link
Member

@TheGupta2012 TheGupta2012 Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use the verification functions from the tests/utils.py file? We have already implemented a lot of robust checks in that file which you are implementing from scratch

Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 qBraid
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for QasmModule.merge().
"""

from pyqasm.entrypoint import loads
from pyqasm.modules import QasmModule


def _qasm3(qasm: str) -> QasmModule:
return loads(qasm)


def test_merge_basic_gates_and_offsets():
qasm_a = (
"""
OPENQASM 3.0;
include "stdgates.inc";
qubit[2] q;
x q[0];
cx q[0], q[1];
"""
)
qasm_b = (
"""
OPENQASM 3.0;
include "stdgates.inc";
qubit[3] r;
h r[0];
cx r[1], r[2];
"""
)

mod_a = _qasm3(qasm_a)
mod_b = _qasm3(qasm_b)

merged = mod_a.merge(mod_b)

# Unrolled representation should have a single consolidated qubit declaration of size 5
text = str(merged)
assert "qubit[5] __PYQASM_QUBITS__;" in text

lines = [l.strip() for l in text.splitlines() if l.strip()]
# Keep only gate lines for comparison; skip version/includes/declarations
gate_lines = [
l
for l in lines
if l[0].isalpha()
and not l.startswith("include")
and not l.startswith("OPENQASM")
and not l.startswith("qubit")
]
assert gate_lines[0].startswith("x __PYQASM_QUBITS__[0]")
assert gate_lines[1].startswith("cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[1]")
assert any(l.startswith("h __PYQASM_QUBITS__[2]") for l in gate_lines)
assert any(l.startswith("cx __PYQASM_QUBITS__[3], __PYQASM_QUBITS__[4]") for l in gate_lines)


def test_merge_with_measurements_and_barriers():
# Module A: 1 qubit + classical 1; has barrier and measure
qasm_a = (
"OPENQASM 3.0;\n"
'include "stdgates.inc";\n'
"qubit[1] qa; bit[1] ca;\n"
"h qa[0];\n"
"barrier qa;\n"
"ca[0] = measure qa[0];\n"
)
# Module B: 2 qubits + classical 2
qasm_b = (
"OPENQASM 3.0;\n"
'include "stdgates.inc";\n'
"qubit[2] qb; bit[2] cb;\n"
"x qb[1];\n"
"cb[1] = measure qb[1];\n"
)

mod_a = _qasm3(qasm_a)
mod_b = _qasm3(qasm_b)

merged = mod_a.merge(mod_b)
merged_text = str(merged)

assert "qubit[3] __PYQASM_QUBITS__;" in merged_text
assert "measure __PYQASM_QUBITS__[2];" in merged_text
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we changing the structure of the measurement statements? I believe the merged qasm should be like -

OPENQASM 3.0;
include "stdgates.inc";
qubit[3] __PYQASM_QUBITS__;
// Module A
bit[1] ca;
h __PYQASM_QUBITS__[0];
barrier __PYQASM_QUBITS__[0];
ca[0] = measure __PYQASM_QUBITS__[0];

// Module B
bit[2] cb;
x __PYQASM_QUBITS__[1];
cb[1] = measure __PYQASM_QUBITS__[1];

assert "barrier __PYQASM_QUBITS__" in merged_text
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert "barrier __PYQASM_QUBITS__" in merged_text
assert "barrier __PYQASM_QUBITS__[0]" in merged_text



def test_merge_qasm2_with_qasm3():
qasm2 = (
"OPENQASM 2.0;\n"
'include "qelib1.inc";\n'
"qreg q[1];\n"
"h q[0];\n"
)
qasm3 = (
"OPENQASM 3.0;\n"
'include "stdgates.inc";\n'
"qubit[2] r;\n"
"x r[0];\n"
)

mod2 = loads(qasm2)
mod3 = loads(qasm3)

merged = mod2.merge(mod3)
text = str(merged)
# Since we are merging starting from a QASM2 module, the merged output
# should remain in QASM2 syntax (qreg), not QASM3 (qubit).
assert "OPENQASM 2.0;" in text
assert 'include "qelib1.inc";' in text
assert "qreg __PYQASM_QUBITS__[3];" in text
assert "x __PYQASM_QUBITS__[1];" in text
Loading