Skip to content

Commit 372ec83

Browse files
david-plweinbe58
andauthored
Implement squin noise rewrite for depolarize2 (#441)
Another thing that came out of #436. --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 339e60d commit 372ec83

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

src/bloqade/squin/noise/rewrite.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from kirin import ir
44
from kirin.passes import Pass
55
from kirin.rewrite import Walk
6-
from kirin.dialects import ilist
6+
from kirin.dialects import py, ilist
77
from kirin.rewrite.abc import RewriteRule, RewriteResult
88

99
from .stmts import (
1010
QubitLoss,
1111
Depolarize,
1212
PauliError,
13+
Depolarize2,
1314
NoiseChannel,
1415
TwoQubitPauliChannel,
1516
SingleQubitPauliChannel,
@@ -57,6 +58,18 @@ def rewrite_single_qubit_pauli_channel(
5758
def rewrite_two_qubit_pauli_channel(
5859
self, node: TwoQubitPauliChannel
5960
) -> RewriteResult:
61+
operator_list = self._insert_two_qubit_paulis_before_node(node)
62+
stochastic_unitary = StochasticUnitaryChannel(
63+
operators=operator_list, probabilities=node.params
64+
)
65+
66+
node.replace_by(stochastic_unitary)
67+
return RewriteResult(has_done_something=True)
68+
69+
@staticmethod
70+
def _insert_two_qubit_paulis_before_node(
71+
node: TwoQubitPauliChannel | Depolarize2,
72+
) -> ir.ResultValue:
6073
paulis = (Identity(sites=1), X(), Y(), Z())
6174
for op in paulis:
6275
op.insert_before(node)
@@ -70,12 +83,7 @@ def rewrite_two_qubit_pauli_channel(
7083
operators.append(op.result)
7184

7285
(operator_list := ilist.New(values=operators)).insert_before(node)
73-
stochastic_unitary = StochasticUnitaryChannel(
74-
operators=operator_list.result, probabilities=node.params
75-
)
76-
77-
node.replace_by(stochastic_unitary)
78-
return RewriteResult(has_done_something=True)
86+
return operator_list.result
7987

8088
def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
8189
paulis = (X(), Y(), Z())
@@ -84,8 +92,14 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
8492
op.insert_before(node)
8593
operators.append(op.result)
8694

95+
# NOTE: need to divide the probability by 3 to get the correct total error rate
96+
(three := py.Constant(3)).insert_before(node)
97+
(p_over_3 := py.Div(node.p, three.result)).insert_before(node)
98+
8799
(operator_list := ilist.New(values=operators)).insert_before(node)
88-
(ps := ilist.New(values=[node.p for _ in range(3)])).insert_before(node)
100+
(ps := ilist.New(values=[p_over_3.result for _ in range(3)])).insert_before(
101+
node
102+
)
89103

90104
stochastic_unitary = StochasticUnitaryChannel(
91105
operators=operator_list.result, probabilities=ps.result
@@ -94,6 +108,21 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
94108

95109
return RewriteResult(has_done_something=True)
96110

111+
def rewrite_depolarize2(self, node: Depolarize2) -> RewriteResult:
112+
operator_list = self._insert_two_qubit_paulis_before_node(node)
113+
114+
# NOTE: need to divide the probability by 15 to get the correct total error rate
115+
(fifteen := py.Constant(15)).insert_before(node)
116+
(p_over_15 := py.Div(node.p, fifteen.result)).insert_before(node)
117+
(probs := ilist.New(values=[p_over_15.result] * 15)).insert_before(node)
118+
119+
stochastic_unitary = StochasticUnitaryChannel(
120+
operators=operator_list, probabilities=probs.result
121+
)
122+
node.replace_by(stochastic_unitary)
123+
124+
return RewriteResult(has_done_something=True)
125+
97126

98127
class RewriteNoiseStmts(Pass):
99128
def unsafe_run(self, mt: ir.Method):

test/pyqrack/squin/test_noise.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,22 @@ def main():
180180
print(ket)
181181

182182
assert math.isclose(abs(ket[2]) ** 2, 1.0, abs_tol=1e-5)
183+
184+
185+
def test_depolarize2():
186+
@squin.kernel
187+
def main():
188+
q = squin.qubit.new(2)
189+
err = squin.noise.depolarize2(0.1)
190+
squin.qubit.apply(err, q[0], q[1])
191+
192+
main.print()
193+
194+
main_ = main.similar(main.dialects)
195+
196+
result = RewriteNoiseStmts(main.dialects)(main_)
197+
assert result.has_done_something
198+
main_.print()
199+
200+
sim = StackMemorySimulator(min_qubits=2)
201+
sim.run(main)

0 commit comments

Comments
 (0)