3
3
from kirin import ir
4
4
from kirin .passes import Pass
5
5
from kirin .rewrite import Walk
6
- from kirin .dialects import ilist
6
+ from kirin .dialects import py , ilist
7
7
from kirin .rewrite .abc import RewriteRule , RewriteResult
8
8
9
9
from .stmts import (
10
10
QubitLoss ,
11
11
Depolarize ,
12
12
PauliError ,
13
+ Depolarize2 ,
13
14
NoiseChannel ,
14
15
TwoQubitPauliChannel ,
15
16
SingleQubitPauliChannel ,
@@ -57,6 +58,18 @@ def rewrite_single_qubit_pauli_channel(
57
58
def rewrite_two_qubit_pauli_channel (
58
59
self , node : TwoQubitPauliChannel
59
60
) -> RewriteResult :
61
+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
62
+ stochastic_unitary = StochasticUnitaryChannel (
63
+ operators = operator_list , probabilities = node .params
64
+ )
65
+
66
+ node .replace_by (stochastic_unitary )
67
+ return RewriteResult (has_done_something = True )
68
+
69
+ @staticmethod
70
+ def _insert_two_qubit_paulis_before_node (
71
+ node : TwoQubitPauliChannel | Depolarize2 ,
72
+ ) -> ir .ResultValue :
60
73
paulis = (Identity (sites = 1 ), X (), Y (), Z ())
61
74
for op in paulis :
62
75
op .insert_before (node )
@@ -70,12 +83,7 @@ def rewrite_two_qubit_pauli_channel(
70
83
operators .append (op .result )
71
84
72
85
(operator_list := ilist .New (values = operators )).insert_before (node )
73
- stochastic_unitary = StochasticUnitaryChannel (
74
- operators = operator_list .result , probabilities = node .params
75
- )
76
-
77
- node .replace_by (stochastic_unitary )
78
- return RewriteResult (has_done_something = True )
86
+ return operator_list .result
79
87
80
88
def rewrite_depolarize (self , node : Depolarize ) -> RewriteResult :
81
89
paulis = (X (), Y (), Z ())
@@ -84,8 +92,14 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
84
92
op .insert_before (node )
85
93
operators .append (op .result )
86
94
95
+ # NOTE: need to divide the probability by 3 to get the correct total error rate
96
+ (three := py .Constant (3 )).insert_before (node )
97
+ (p_over_3 := py .Div (node .p , three .result )).insert_before (node )
98
+
87
99
(operator_list := ilist .New (values = operators )).insert_before (node )
88
- (ps := ilist .New (values = [node .p for _ in range (3 )])).insert_before (node )
100
+ (ps := ilist .New (values = [p_over_3 .result for _ in range (3 )])).insert_before (
101
+ node
102
+ )
89
103
90
104
stochastic_unitary = StochasticUnitaryChannel (
91
105
operators = operator_list .result , probabilities = ps .result
@@ -94,6 +108,21 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
94
108
95
109
return RewriteResult (has_done_something = True )
96
110
111
+ def rewrite_depolarize2 (self , node : Depolarize2 ) -> RewriteResult :
112
+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
113
+
114
+ # NOTE: need to divide the probability by 15 to get the correct total error rate
115
+ (fifteen := py .Constant (15 )).insert_before (node )
116
+ (p_over_15 := py .Div (node .p , fifteen .result )).insert_before (node )
117
+ (probs := ilist .New (values = [p_over_15 .result ] * 15 )).insert_before (node )
118
+
119
+ stochastic_unitary = StochasticUnitaryChannel (
120
+ operators = operator_list , probabilities = probs .result
121
+ )
122
+ node .replace_by (stochastic_unitary )
123
+
124
+ return RewriteResult (has_done_something = True )
125
+
97
126
98
127
class RewriteNoiseStmts (Pass ):
99
128
def unsafe_run (self , mt : ir .Method ):
0 commit comments