Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/unroll_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=invalid-name
# pylint: disable=invalid-name, cyclic-import

"""
Script demonstrating how to unroll a QASM 3 program using pyqasm.
Expand Down
73 changes: 73 additions & 0 deletions src/pyqasm/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import functools
import re
from abc import ABC, abstractmethod
from collections import Counter
from copy import deepcopy
Expand Down Expand Up @@ -761,3 +762,75 @@ def accept(self, visitor):
Args:
visitor (QasmVisitor): The visitor to accept
"""

@abstractmethod
def merge(
self,
other: "QasmModule",
device_qubits: Optional[int] = None,
) -> "QasmModule":
"""Merge this module with another module.

Implemented by concrete subclasses to avoid version mixing and
import-time cycles. Implementations should ensure both operands
are normalized to the same version prior to merging.
"""


def offset_statement_qubits(
stmt: qasm3_ast.Statement, offset: int
): # pylint: disable=too-many-branches
"""Offset qubit indices for a given statement in-place by ``offset``.
Handles gates, measurements, resets, and barriers (including slice forms).
"""
if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement):
bit = stmt.measure.qubit
if isinstance(bit, qasm3_ast.IndexedIdentifier):
for group in bit.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumGate):
for q in stmt.qubits:
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumReset):
q = stmt.qubits
if isinstance(q, qasm3_ast.IndexedIdentifier):
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumBarrier):
qubits = stmt.qubits
if len(qubits) == 0:
return
first = qubits[0]
if isinstance(first, qasm3_ast.IndexedIdentifier):
for group in first.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
elif isinstance(first, qasm3_ast.Identifier):
# Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E]
name = first.name
if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"):
slice_str = name[len("__PYQASM_QUBITS__") :]
# Parse slice forms [S:E], [:E], or [S:]
m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str)
if m:
start_s, end_s = m.group(1), m.group(2)
if start_s is None and end_s is not None:
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[:{end_v}]"
elif start_s is not None and end_s is None:
start_v = int(start_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:]"
elif start_s is not None and end_s is not None:
start_v = int(start_s) + offset
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]"
130 changes: 130 additions & 0 deletions src/pyqasm/modules/qasm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,62 @@
from pyqasm.modules.base import QasmModule
from pyqasm.modules.qasm3 import Qasm3Module

try:
from pyqasm.modules.base import offset_statement_qubits # type: ignore
except Exception: # pylint: disable=broad-except

def offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): # type: ignore[override] # pylint: disable=too-many-branches
"""Offset qubit indices for a given statement in-place by ``offset``."""
if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement):
bit = stmt.measure.qubit
if isinstance(bit, qasm3_ast.IndexedIdentifier):
for group in bit.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumGate):
for q in stmt.qubits:
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumReset):
q = stmt.qubits
if isinstance(q, qasm3_ast.IndexedIdentifier):
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumBarrier):
qubits = stmt.qubits
if len(qubits) == 0:
return
first = qubits[0]
if isinstance(first, qasm3_ast.IndexedIdentifier):
for group in first.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
elif isinstance(first, qasm3_ast.Identifier):
name = first.name
if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"):
slice_str = name[len("__PYQASM_QUBITS__") :]
m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str)
if m:
start_s, end_s = m.group(1), m.group(2)
if start_s is None and end_s is not None:
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[:{end_v}]"
elif start_s is not None and end_s is None:
start_v = int(start_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:]"
elif start_s is not None and end_s is not None:
start_v = int(start_s) + offset
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]"


class Qasm2Module(QasmModule):
"""
Expand Down Expand Up @@ -108,3 +164,77 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self.unrolled_ast.statements = final_stmt_list

def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule:
"""Merge two modules and return a QASM2 result without mixing versions.
- If ``other`` is QASM3, it is merged into this module's semantics, and
any standard gate includes are mapped to ``qelib1.inc``.
- The merged program keeps version "2.0" and prints as QASM2.
"""
if not isinstance(other, QasmModule):
raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}")

left_mod = self.copy()
right_mod = other.copy()

# Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__
unroll_kwargs: dict[str, object] = {"consolidate_qubits": True}
if device_qubits is not None:
unroll_kwargs["device_qubits"] = device_qubits

left_mod.unroll(**unroll_kwargs)
right_mod.unroll(**unroll_kwargs)

left_qubits = left_mod.num_qubits

merged_program = Program(statements=[], version="2.0")

# Unique includes first; map stdgates.inc -> qelib1.inc for QASM2
include_names: set[str] = set()
for module in (left_mod, right_mod):
for stmt in module.unrolled_ast.statements:
if isinstance(stmt, Include):
fname = stmt.filename
if fname == "stdgates.inc":
fname = "qelib1.inc"
include_names.add(fname)
for fname in include_names:
merged_program.statements.append(Include(filename=fname))

# Consolidated qubit declaration (converted to qreg on print)
merged_program.statements.append(
qasm3_ast.QubitDeclaration(
size=qasm3_ast.IntegerLiteral(value=left_qubits + right_mod.num_qubits),
qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"),
)
)

# Append left ops (skip decls and includes)
for stmt in left_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)):
continue
merged_program.statements.append(deepcopy(stmt))

# Append right ops with index offset
for stmt in right_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)):
continue
stmt = deepcopy(stmt)
offset_statement_qubits(stmt, left_qubits)
merged_program.statements.append(stmt)

merged_module = Qasm2Module(
name=f"{left_mod.name}_merged_{right_mod.name}",
program=merged_program,
)
merged_module.unrolled_ast = Program(
statements=list(merged_program.statements),
version="2.0",
)
merged_module._external_gates = list(
{*left_mod._external_gates, *right_mod._external_gates}
)
merged_module._user_operations = list(left_mod.history) + list(right_mod.history)
merged_module._user_operations.append(f"merge(other={right_mod.name})")
merged_module.validate()
return merged_module
136 changes: 136 additions & 0 deletions src/pyqasm/modules/qasm3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,71 @@
Defines a module for handling OpenQASM 3.0 programs.
"""

import re

import openqasm3.ast as qasm3_ast
from openqasm3.ast import Program
from openqasm3.printer import dumps

from pyqasm.modules.base import QasmModule

# Backward-compat: older installed versions may not export offset_statement_qubits
try:
from pyqasm.modules.base import offset_statement_qubits # type: ignore
except Exception: # pylint: disable=broad-except

def offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): # type: ignore[override] # pylint: disable=too-many-branches
"""Offset qubit indices for a given statement in-place by ``offset``."""
if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement):
bit = stmt.measure.qubit
if isinstance(bit, qasm3_ast.IndexedIdentifier):
for group in bit.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumGate):
for q in stmt.qubits:
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumReset):
q = stmt.qubits
if isinstance(q, qasm3_ast.IndexedIdentifier):
for group in q.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
return

if isinstance(stmt, qasm3_ast.QuantumBarrier):
qubits = stmt.qubits
if len(qubits) == 0:
return
first = qubits[0]
if isinstance(first, qasm3_ast.IndexedIdentifier):
for group in first.indices:
for ind in group:
ind.value += offset # type: ignore[attr-defined]
elif isinstance(first, qasm3_ast.Identifier):
name = first.name
if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"):
slice_str = name[len("__PYQASM_QUBITS__") :]
m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str)
if m:
start_s, end_s = m.group(1), m.group(2)
if start_s is None and end_s is not None:
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[:{end_v}]"
elif start_s is not None and end_s is None:
start_v = int(start_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:]"
elif start_s is not None and end_s is not None:
start_v = int(start_s) + offset
end_v = int(end_s) + offset
first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]"


class Qasm3Module(QasmModule):
"""
Expand Down Expand Up @@ -52,3 +112,79 @@ def accept(self, visitor):
final_stmt_list = visitor.finalize(unrolled_stmt_list)

self._unrolled_ast.statements = final_stmt_list

def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule:
"""Merge two modules as OpenQASM 3.0 without mixing versions.
If ``other`` is QASM2, it will be converted to QASM3 before merging.
The merged program keeps version "3.0".
"""
if not isinstance(other, QasmModule):
raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}")

# Convert right to QASM3 if it supports conversion; otherwise copy
convert = getattr(other, "to_qasm3", None)
right_mod = (
convert(as_str=False) if callable(convert) else other.copy()
) # type: ignore[assignment]

left_mod = self.copy()

# Unroll with consolidation so both use __PYQASM_QUBITS__
unroll_kwargs: dict[str, object] = {"consolidate_qubits": True}
if device_qubits is not None:
unroll_kwargs["device_qubits"] = device_qubits

left_mod.unroll(**unroll_kwargs)
right_mod.unroll(**unroll_kwargs)

left_qubits = left_mod.num_qubits
total_qubits = left_qubits + right_mod.num_qubits

merged_program = Program(statements=[], version="3.0")

# Unique includes first
include_names: set[str] = set()
for module in (left_mod, right_mod):
for stmt in module.unrolled_ast.statements:
if isinstance(stmt, qasm3_ast.Include):
include_names.add(stmt.filename)
for fname in include_names:
merged_program.statements.append(qasm3_ast.Include(filename=fname))

# Consolidated qubit declaration
merged_program.statements.append(
qasm3_ast.QubitDeclaration(
size=qasm3_ast.IntegerLiteral(value=total_qubits),
qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"),
)
)

# Append left ops
for stmt in left_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
merged_program.statements.append(stmt)

# Append right ops with index offset
for stmt in right_mod.unrolled_ast.statements:
if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)):
continue
# right_mod is a copy, so it's safe to modify statements in place
offset_statement_qubits(stmt, left_qubits)
merged_program.statements.append(stmt)

merged_module = Qasm3Module(
name=f"{left_mod.name}_merged_{right_mod.name}",
program=merged_program,
)
merged_module.unrolled_ast = Program(
statements=list(merged_program.statements),
version="3.0",
)
merged_module._external_gates = list(
{*left_mod._external_gates, *right_mod._external_gates}
)
merged_module._user_operations = list(left_mod.history) + list(right_mod.history)
merged_module._user_operations.append(f"merge(other={right_mod.name})")
merged_module.validate()
return merged_module
Loading
Loading