3
3
from kirin import ir
4
4
from kirin .passes import Pass
5
5
from kirin .rewrite import Walk
6
- from kirin .dialects import ilist
6
+ from kirin .dialects import py , ilist
7
7
from kirin .rewrite .abc import RewriteRule , RewriteResult
8
8
9
9
from .stmts import (
10
10
PPError ,
11
11
QubitLoss ,
12
12
Depolarize ,
13
13
PauliError ,
14
+ Depolarize2 ,
14
15
NoiseChannel ,
15
16
TwoQubitPauliChannel ,
16
17
SingleQubitPauliChannel ,
@@ -58,6 +59,18 @@ def rewrite_single_qubit_pauli_channel(
58
59
def rewrite_two_qubit_pauli_channel (
59
60
self , node : TwoQubitPauliChannel
60
61
) -> RewriteResult :
62
+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
63
+ stochastic_unitary = StochasticUnitaryChannel (
64
+ operators = operator_list , probabilities = node .params
65
+ )
66
+
67
+ node .replace_by (stochastic_unitary )
68
+ return RewriteResult (has_done_something = True )
69
+
70
+ @staticmethod
71
+ def _insert_two_qubit_paulis_before_node (
72
+ node : TwoQubitPauliChannel | Depolarize2 ,
73
+ ) -> ir .ResultValue :
61
74
paulis = (Identity (sites = 1 ), X (), Y (), Z ())
62
75
for op in paulis :
63
76
op .insert_before (node )
@@ -71,12 +84,7 @@ def rewrite_two_qubit_pauli_channel(
71
84
operators .append (op .result )
72
85
73
86
(operator_list := ilist .New (values = operators )).insert_before (node )
74
- stochastic_unitary = StochasticUnitaryChannel (
75
- operators = operator_list .result , probabilities = node .params
76
- )
77
-
78
- node .replace_by (stochastic_unitary )
79
- return RewriteResult (has_done_something = True )
87
+ return operator_list .result
80
88
81
89
def rewrite_p_p_error (self , node : PPError ) -> RewriteResult :
82
90
(operators := ilist .New (values = (node .op ,))).insert_before (node )
@@ -95,8 +103,14 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
95
103
op .insert_before (node )
96
104
operators .append (op .result )
97
105
106
+ # NOTE: need to divide the probability by 3 to get the correct total error rate
107
+ (three := py .Constant (3 )).insert_before (node )
108
+ (p_over_3 := py .Div (node .p , three .result )).insert_before (node )
109
+
98
110
(operator_list := ilist .New (values = operators )).insert_before (node )
99
- (ps := ilist .New (values = [node .p for _ in range (3 )])).insert_before (node )
111
+ (ps := ilist .New (values = [p_over_3 .result for _ in range (3 )])).insert_before (
112
+ node
113
+ )
100
114
101
115
stochastic_unitary = StochasticUnitaryChannel (
102
116
operators = operator_list .result , probabilities = ps .result
@@ -105,6 +119,21 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
105
119
106
120
return RewriteResult (has_done_something = True )
107
121
122
+ def rewrite_depolarize2 (self , node : Depolarize2 ) -> RewriteResult :
123
+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
124
+
125
+ # NOTE: need to divide the probability by 15 to get the correct total error rate
126
+ (fifteen := py .Constant (15 )).insert_before (node )
127
+ (p_over_15 := py .Div (node .p , fifteen .result )).insert_before (node )
128
+ (probs := ilist .New (values = [p_over_15 .result ] * 15 )).insert_before (node )
129
+
130
+ stochastic_unitary = StochasticUnitaryChannel (
131
+ operators = operator_list , probabilities = probs .result
132
+ )
133
+ node .replace_by (stochastic_unitary )
134
+
135
+ return RewriteResult (has_done_something = True )
136
+
108
137
109
138
class RewriteNoiseStmts (Pass ):
110
139
def unsafe_run (self , mt : ir .Method ):
0 commit comments