Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/analysis/fidelity/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import impls as impls
from .analysis import FidelityAnalysis as FidelityAnalysis
3 changes: 3 additions & 0 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class FidelityAnalysis(Forward):
"""
This analysis pass can be used to track the global addresses of qubits and wires.

**NOTE**: nested kernels are currently not supported, so instead of calling a kernel
from another kernel please inline it.

## Usage examples

```
Expand Down
78 changes: 78 additions & 0 deletions src/bloqade/analysis/fidelity/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
from kirin import interp
from kirin.lattice import EmptyLattice
from kirin.analysis import const
from kirin.dialects import scf
from kirin.dialects.scf import For, Yield, IfElse

from .analysis import FidelityAnalysis


@scf.dialect.register(key="circuit.fidelity")
class ScfFidelityMethodTable(interp.MethodTable):

@interp.impl(IfElse)
def if_else(
self,
interp: FidelityAnalysis,
frame: interp.Frame[EmptyLattice],
stmt: IfElse,
):
# NOTE: store current fidelity for later
current_gate_fidelity = interp._current_gate_fidelity
current_atom_survival = interp._current_atom_survival_probability

for s in stmt.then_body.stmts():
stmt_impl = interp.lookup_registry(frame=frame, stmt=s)
if stmt_impl is None:
continue

stmt_impl(interp=interp, frame=frame, stmt=s)

then_gate_fidelity = interp._current_gate_fidelity
then_atom_survival = interp._current_atom_survival_probability

# NOTE: reset fidelity of interp to check if the else body results in a worse fidelity
interp._current_gate_fidelity = current_gate_fidelity
interp._current_atom_survival_probability = current_atom_survival

for s in stmt.else_body.stmts():
stmt_impl = interp.lookup_registry(frame=frame, stmt=s)
if stmt_impl is None:
continue

Check warning on line 42 in src/bloqade/analysis/fidelity/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/fidelity/impls.py#L42

Added line #L42 was not covered by tests

stmt_impl(interp=interp, frame=frame, stmt=s)

else_gate_fidelity = interp._current_gate_fidelity
else_atom_survival = interp._current_atom_survival_probability

# NOTE: look for the "worse" branch
then_combined_fidelity = then_gate_fidelity * np.prod(then_atom_survival)
else_combined_fidelity = else_gate_fidelity * np.prod(else_atom_survival)

if then_combined_fidelity < else_combined_fidelity:
interp._current_gate_fidelity = then_gate_fidelity
interp._current_atom_survival_probability = then_atom_survival
else:
interp._current_gate_fidelity = else_gate_fidelity
interp._current_atom_survival_probability = else_atom_survival

Check warning on line 58 in src/bloqade/analysis/fidelity/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/fidelity/impls.py#L57-L58

Added lines #L57 - L58 were not covered by tests

@interp.impl(Yield)
def yield_(
self, interp: FidelityAnalysis, frame: interp.Frame[EmptyLattice], stmt: Yield
):
# NOTE: yield can by definition only contain values, never any stmts, so fidelity cannot decrease
return

@interp.impl(For)
def for_loop(
self, interp: FidelityAnalysis, frame: interp.Frame[EmptyLattice], stmt: For
):
if not isinstance(hint := stmt.iterable.hints.get("const"), const.Value):

Check warning on line 71 in src/bloqade/analysis/fidelity/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/fidelity/impls.py#L71

Added line #L71 was not covered by tests
# NOTE: not clear how long this loop is
# TODO: should we at least count the body once?
return

Check warning on line 74 in src/bloqade/analysis/fidelity/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/fidelity/impls.py#L74

Added line #L74 was not covered by tests

for _ in hint.data:
for s in stmt.body.stmts():
interp.eval_stmt(frame=frame, stmt=s)

Check warning on line 78 in src/bloqade/analysis/fidelity/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/fidelity/impls.py#L76-L78

Added lines #L76 - L78 were not covered by tests
98 changes: 94 additions & 4 deletions test/analysis/fidelity/test_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def parallel_cz_errors(self, ctrls, qargs, rest):
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}


@pytest.mark.xfail
def test_if():

@qasm2.extended
Expand Down Expand Up @@ -115,7 +114,7 @@ def main_if():
p_loss = 0.01

model = NoiseTestModel(
global_loss_prob=p_loss,
local_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
Expand All @@ -125,7 +124,6 @@ def main_if():
fid_analysis = FidelityAnalysis(main.dialects)
fid_analysis.run_analysis(main, no_raise=False)

model = NoiseTestModel()
NoisePass(main_if.dialects, noise_model=model)(main_if)
fid_if_analysis = FidelityAnalysis(main_if.dialects)
fid_if_analysis.run_analysis(main_if, no_raise=False)
Expand All @@ -139,7 +137,7 @@ def main_if():
)


@pytest.mark.xfail
@pytest.mark.xfail # NOTE: currently fails because of kirin issue #408
def test_for():

@qasm2.extended
Expand Down Expand Up @@ -203,3 +201,95 @@ def main_for():
== fid_analysis.atom_survival_probability[0]
< 1
)


# NOTE: nested kernels currently not supported (AddressAnalysis doesn't support it and we need that for both the NoisePass and the FidelityAnalysis)
@pytest.mark.xfail
def test_nested_kernel():
@qasm2.extended
def nested():
q = qasm2.qreg(2)
qasm2.h(q[0])

qasm2.cx(q[0], q[1])

return q

@qasm2.extended
def main():
q = nested()
return q

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

model = NoiseTestModel(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
)

NoisePass(main.dialects, noise_model=model).unsafe_run(nested)

fid_nested = FidelityAnalysis(main.dialects)
fid_nested.run_analysis(nested, no_raise=False)

fid_main = FidelityAnalysis(main.dialects)
fid_main.run_analysis(main, no_raise=False)

assert fid_main.gate_fidelity == fid_nested.gate_fidelity
assert fid_main.atom_survival_probability == fid_nested.atom_survival_probability


@pytest.mark.xfail
def test_nested_kernel_with_more_stmts():
@qasm2.extended
def nested():
q = qasm2.qreg(2)
qasm2.h(q[0])

qasm2.cx(q[0], q[1])

return q

@qasm2.extended
def main():
q = nested()

qasm2.h(q[0])

qasm2.cx(q[0], q[1])
return q

px = 0.01
py = 0.01
pz = 0.01
p_loss = 0.01

model = NoiseTestModel(
global_loss_prob=p_loss,
global_px=px,
global_py=py,
global_pz=pz,
local_px=0.002,
)

noise_pass = NoisePass(main.dialects, noise_model=model)

noise_pass(main)
noise_pass(nested)

fid_nested = FidelityAnalysis(main.dialects)
fid_nested.run_analysis(nested, no_raise=False)

fid_main = FidelityAnalysis(main.dialects)
fid_main.run_analysis(main, no_raise=False)

assert fid_main.gate_fidelity == fid_nested.gate_fidelity**2
assert fid_main.atom_survival_probability == [
prob**2 for prob in fid_nested.atom_survival_probability
]