Skip to content

Commit cb8fa28

Browse files
david-plweinbe58
andcommitted
Implement squin noise rewrite for depolarize2 (#441)
Another thing that came out of #436. --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent c245ec6 commit cb8fa28

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

src/bloqade/squin/noise/rewrite.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from kirin import ir
44
from kirin.passes import Pass
55
from kirin.rewrite import Walk
6-
from kirin.dialects import ilist
6+
from kirin.dialects import py, ilist
77
from kirin.rewrite.abc import RewriteRule, RewriteResult
88

99
from .stmts import (
1010
PPError,
1111
QubitLoss,
1212
Depolarize,
1313
PauliError,
14+
Depolarize2,
1415
NoiseChannel,
1516
TwoQubitPauliChannel,
1617
SingleQubitPauliChannel,
@@ -58,6 +59,18 @@ def rewrite_single_qubit_pauli_channel(
5859
def rewrite_two_qubit_pauli_channel(
5960
self, node: TwoQubitPauliChannel
6061
) -> RewriteResult:
62+
operator_list = self._insert_two_qubit_paulis_before_node(node)
63+
stochastic_unitary = StochasticUnitaryChannel(
64+
operators=operator_list, probabilities=node.params
65+
)
66+
67+
node.replace_by(stochastic_unitary)
68+
return RewriteResult(has_done_something=True)
69+
70+
@staticmethod
71+
def _insert_two_qubit_paulis_before_node(
72+
node: TwoQubitPauliChannel | Depolarize2,
73+
) -> ir.ResultValue:
6174
paulis = (Identity(sites=1), X(), Y(), Z())
6275
for op in paulis:
6376
op.insert_before(node)
@@ -71,12 +84,7 @@ def rewrite_two_qubit_pauli_channel(
7184
operators.append(op.result)
7285

7386
(operator_list := ilist.New(values=operators)).insert_before(node)
74-
stochastic_unitary = StochasticUnitaryChannel(
75-
operators=operator_list.result, probabilities=node.params
76-
)
77-
78-
node.replace_by(stochastic_unitary)
79-
return RewriteResult(has_done_something=True)
87+
return operator_list.result
8088

8189
def rewrite_p_p_error(self, node: PPError) -> RewriteResult:
8290
(operators := ilist.New(values=(node.op,))).insert_before(node)
@@ -95,8 +103,14 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
95103
op.insert_before(node)
96104
operators.append(op.result)
97105

106+
# NOTE: need to divide the probability by 3 to get the correct total error rate
107+
(three := py.Constant(3)).insert_before(node)
108+
(p_over_3 := py.Div(node.p, three.result)).insert_before(node)
109+
98110
(operator_list := ilist.New(values=operators)).insert_before(node)
99-
(ps := ilist.New(values=[node.p for _ in range(3)])).insert_before(node)
111+
(ps := ilist.New(values=[p_over_3.result for _ in range(3)])).insert_before(
112+
node
113+
)
100114

101115
stochastic_unitary = StochasticUnitaryChannel(
102116
operators=operator_list.result, probabilities=ps.result
@@ -105,6 +119,21 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
105119

106120
return RewriteResult(has_done_something=True)
107121

122+
def rewrite_depolarize2(self, node: Depolarize2) -> RewriteResult:
123+
operator_list = self._insert_two_qubit_paulis_before_node(node)
124+
125+
# NOTE: need to divide the probability by 15 to get the correct total error rate
126+
(fifteen := py.Constant(15)).insert_before(node)
127+
(p_over_15 := py.Div(node.p, fifteen.result)).insert_before(node)
128+
(probs := ilist.New(values=[p_over_15.result] * 15)).insert_before(node)
129+
130+
stochastic_unitary = StochasticUnitaryChannel(
131+
operators=operator_list, probabilities=probs.result
132+
)
133+
node.replace_by(stochastic_unitary)
134+
135+
return RewriteResult(has_done_something=True)
136+
108137

109138
class RewriteNoiseStmts(Pass):
110139
def unsafe_run(self, mt: ir.Method):

test/pyqrack/squin/test_noise.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,22 @@ def main():
209209
print(ket)
210210

211211
assert math.isclose(abs(ket[2]) ** 2, 1.0, abs_tol=1e-5)
212+
213+
214+
def test_depolarize2():
215+
@squin.kernel
216+
def main():
217+
q = squin.qubit.new(2)
218+
err = squin.noise.depolarize2(0.1)
219+
squin.qubit.apply(err, q[0], q[1])
220+
221+
main.print()
222+
223+
main_ = main.similar(main.dialects)
224+
225+
result = RewriteNoiseStmts(main.dialects)(main_)
226+
assert result.has_done_something
227+
main_.print()
228+
229+
sim = StackMemorySimulator(min_qubits=2)
230+
sim.run(main)

0 commit comments

Comments
 (0)