Skip to content
22 changes: 21 additions & 1 deletion src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
qubit.address method table for a few builtin dialects.
"""

from kirin import interp
from kirin import types, interp
from kirin.analysis import ForwardFrame, const
from kirin.dialects import cf, py, scf, func, ilist

from bloqade.types import QubitType

from .lattice import (
Address,
NotQubit,
AddressReg,
AnyAddress,
AddressQubit,
AddressTuple,
)
Expand Down Expand Up @@ -53,6 +56,21 @@ def new_ilist(
):
return (AddressTuple(frame.get_values(stmt.values)),)

interp.impl(ilist.Map)

def _map(
self,
interp: AddressAnalysis,
frame: interp.Frame,
stmt: ilist.Map,
):
if not stmt.result.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
return (NotQubit(),)

# TODO: try to go into function call if possible
# we need to update the lattice to accept python constants
return (AnyAddress(),)


@py.list.dialect.register(key="qubit.address")
class PyList(interp.MethodTable):
Expand Down Expand Up @@ -119,6 +137,8 @@ def return_(self, _: AddressAnalysis, frame: interp.Frame, stmt: func.Return):
# TODO: replace with the generic implementation
@interp.impl(func.Invoke)
def invoke(self, interp_: AddressAnalysis, frame: interp.Frame, stmt: func.Invoke):
print("Invoke:", stmt.callee)
stmt.callee.code.print()
_, ret = interp_.run_method(
stmt.callee,
interp_.permute_values(
Expand Down
8 changes: 8 additions & 0 deletions src/bloqade/pyqrack/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
)
return (qreg,)

@interp.impl(qubit.NewQubit)
def new_qubit(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.NewQubit
):
(addr,) = interp.memory.allocate(1)
qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
return (qb,)

@interp.impl(qubit.Apply)
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
qubits: list[PyQrackQubit] = [frame.get(qbit) for qbit in stmt.qubits]
Expand Down
1 change: 1 addition & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_typeinfer as _typeinfer,
)
from .groups import wired as wired, kernel as kernel
from .stdlib.qubit import new as new
from .stdlib.simple import (
h as h,
s as s,
Expand Down
11 changes: 11 additions & 0 deletions src/bloqade/squin/analysis/address_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,14 @@ def new(
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
interp_.next_address += n_qubits
return (addr,)

@interp.impl(qubit.NewQubit)
def new_qubit(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: qubit.NewQubit,
):
addr = AddressQubit(interp_.next_address)
interp_.next_address += 1
return (addr,)
16 changes: 16 additions & 0 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class New(ir.Statement):
result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class NewQubit(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
result: ir.ResultValue = info.result(QubitType)


@statement(dialect=dialect)
class Apply(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
Expand Down Expand Up @@ -106,6 +112,16 @@ def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
...


@wraps(NewQubit)
def new_qubit() -> Qubit:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename NewQubit -> New

"""Create a new qubit.

Returns:
Qubit: A new qubit.
"""
...


@wraps(ApplyAny)
def apply(operator: Op, *qubits: Qubit) -> None:
"""Apply an operator to qubits. The number of qubit arguments must match the
Expand Down
21 changes: 21 additions & 0 deletions src/bloqade/squin/stdlib/qubit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from kirin.dialects import ilist

from .. import kernel
from ..qubit import new_qubit


@kernel(typeinfer=True)
def new(n_qubits: int):
"""Create a new list of qubits.

Args:
n_qubits(int): The number of qubits to create.

Returns:
(ilist.IList[Qubit, n_qubits]) A list of qubits.
"""

def _new(qid: int):
return new_qubit()

return ilist.map(_new, ilist.range(n_qubits))
25 changes: 24 additions & 1 deletion test/analysis/address/test_qubit_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from util import collect_address_types

from bloqade.squin import op, qubit, kernel
from bloqade.squin import op, new, qubit, kernel
from bloqade.analysis import address

# test tuple and indexing
Expand Down Expand Up @@ -126,3 +126,26 @@ def main():

address_analysis = address.AddressAnalysis(main.dialects)
address_analysis.run_analysis(main, no_raise=False)


def test_new_qubit():
@kernel
def main():
return qubit.new_qubit()

address_analysis = address.AddressAnalysis(main.dialects)
_, result = address_analysis.run_analysis(main, no_raise=False)
assert result == address.AddressQubit(0)


@pytest.mark.xfail # fails due to ilist.map not being handled correctly
def test_new_stdlib():
@kernel
def main():
return new(10)

address_analysis = address.AddressAnalysis(main.dialects)
_, result = address_analysis.run_analysis(main, no_raise=False)
assert (
result == address.AnyAddress()
) # TODO: should be AddressTuple with AddressQubits
9 changes: 4 additions & 5 deletions test/native/test_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from kirin import ir
from kirin.dialects import ilist

from bloqade import native
from bloqade.squin import qubit
from bloqade import squin, native
from bloqade.pyqrack import DynamicMemorySimulator


def test_ghz():

@native.kernel(typeinfer=True, fold=True)
def main():
qreg = qubit.new(4)
qreg = squin.new(4)

native.h(qreg[0])

Expand Down Expand Up @@ -57,10 +56,10 @@ def main():
(native.s_dag, [1.0, 0.0]),
],
)
def test_1q_gate(gate_func: ir.Method[[qubit.Qubit], None], expected: Any):
def test_1q_gate(gate_func: ir.Method, expected: Any):
@native.kernel
def main():
q = qubit.new(1)
q = squin.new(1)
gate_func(q[0])

sv = DynamicMemorySimulator().state_vector(main)
Expand Down