Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .compiler_transform import compiler_transform
from .passes import PLModulePass
from .pattern_rewriter import PLPatternRewriter


class CancelInverses(PLModulePass):
Expand Down Expand Up @@ -58,7 +59,7 @@ def can_cancel(op: quantum.CustomOp, next_op: Operation) -> bool:


@CancelInverses.rewrite_rule(quantum.CustomOp)
def rewrite_custom_op(self, op, rewriter):
def rewrite_custom_op(self, op: quantum.CustomOp, rewriter: PLPatternRewriter):
"""Rewrite rule for CustomOp."""
while isinstance(op, quantum.CustomOp) and op.gate_name.data in self.self_inverses:
next_user = None
Expand All @@ -71,31 +72,27 @@ def rewrite_custom_op(self, op, rewriter):
if next_user is None:
break

for q1, q2 in zip(op.in_qubits, next_user.out_qubits, strict=True):
rewriter.replace_all_uses_with(q2, q1)
for cq1, cq2 in zip(op.in_ctrl_qubits, next_user.out_ctrl_qubits, strict=True):
rewriter.replace_all_uses_with(cq2, cq1)
rewriter.erase_op(next_user)
rewriter.erase_op(op)
rewriter.erase_gate(next_user)
rewriter.erase_gate(op)
op = op.in_qubits[0].owner


# We can register more rewrite rules as needed. Here are some
# dummy rewrite rules to illustrate:
@CancelInverses.rewrite_rule(quantum.InsertOp)
def rewrite_insert_op(self, op, rewriter):
def rewrite_insert_op(self, op: quantum.InsertOp, rewriter: PLPatternRewriter):
"""Rewrite rule for InsertOp."""
return


@CancelInverses.rewrite_rule(quantum.ExtractOp)
def rewrite_extract_op(self, op, rewriter):
def rewrite_extract_op(self, op: quantum.ExtractOp, rewriter: PLPatternRewriter):
"""Rewrite rule for ExtractOp."""
return


@CancelInverses.rewrite_rule(quantum.MeasureOp)
def rewrite_mid_measure_op(self, op, rewriter):
def rewrite_mid_measure_op(self, op: quantum.MeasureOp, rewriter: PLPatternRewriter):
"""Rewrite rule for MeasureOp."""
return

Expand Down
14 changes: 7 additions & 7 deletions frontend/catalyst/python_interface/pass_api/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)

from .pattern_rewriter import PLPatternRewriter, PLPatternRewriteWalker


def _update_type_hints(hint: type[Operation] | type[Operation]) -> Callable:
"""Update the signature of a ``match_and_rewrite`` method to use the provided operation
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(self, _pass):

@op_type_rewrite_pattern
@_update_type_hints(hint)
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
def match_and_rewrite(self, op: Operation, rewriter: PLPatternRewriter) -> None:
rewrite_rule(self._pass, op, rewriter)

return _RewritePattern
Expand Down Expand Up @@ -120,7 +120,7 @@ def rewrite_rule(
.. code-block:: python

@PLModulePass.rewrite_rule(MyOperation)
def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None:
def rewrite_myop(self, op: MyOperation, rewriter: PLPatternRewriter) -> None:
...

.. note::
Expand All @@ -135,7 +135,7 @@ def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None:
Callable: a decorator to register the rewrite rule with the ModulePass
"""

def decorator(rule: Callable[[Operation, PatternRewriter], None]) -> Callable:
def decorator(rule: Callable[[Operation, PLPatternRewriter], None]) -> Callable:
rewrite_pattern = _create_rewrite_pattern(hint, rule)
cls._rewrite_patterns[hint] = rewrite_pattern
return rule
Expand Down Expand Up @@ -170,10 +170,10 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable=
pattern = GreedyRewritePatternApplier(
rewrite_patterns=[rp(self) for rp in self._rewrite_patterns.values()]
)
walker = PatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive)
walker = PLPatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive)
walker.rewrite_module(op)

else:
for rp in self._rewrite_patterns.values():
walker = PatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive)
walker = PLPatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive)
walker.rewrite_module(op)
Loading
Loading