Skip to content
10 changes: 6 additions & 4 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def emit_circuit(
Keyword Args:
circuit_qubits (Sequence[cirq.Qid] | None):
A list of qubits to use as the qubits in the circuit. Defaults to None.
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc`
statement in the order they appear inside the kernel.
**Note**: If a list of qubits is provided, make sure that there is a sufficient
number of qubits for the resulting circuit.
Expand All @@ -48,7 +48,7 @@ def emit_circuit(

@squin.kernel
def main():
q = squin.qubit.new(2)
q = squin.qalloc(2)
squin.h(q[0])
squin.cx(q[0], q[1])

Expand All @@ -74,8 +74,10 @@ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):

@squin.kernel
def main():
q = squin.qubit.new(2)
entangle(q)
q = squin.qalloc(2)
q2 = squin.qalloc(3)
squin.cx(q[1], q2[2])


# custom list of qubits on grid
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
Expand Down
14 changes: 5 additions & 9 deletions src/bloqade/cirq_utils/emit/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@
class EmitCirqQubitMethods(MethodTable):
@impl(qubit.New)
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
n_qubits = frame.get(stmt.n_qubits)

if frame.qubits is not None:
cirq_qubits = tuple(
frame.qubits[i + frame.qubit_index] for i in range(n_qubits)
)
cirq_qubit = frame.qubits[frame.qubit_index]
else:
cirq_qubits = tuple(
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
)
cirq_qubit = cirq.LineQubit(frame.qubit_index)

frame.qubit_index += n_qubits
return (cirq_qubits,)
frame.has_allocations = True
frame.qubit_index += 1
return (cirq_qubit,)

@impl(qubit.Apply)
def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
Expand Down
8 changes: 5 additions & 3 deletions src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kirin.rewrite import Walk, CFGCompactify
from kirin.dialects import py, scf, func, ilist

from bloqade.squin import gate, noise, qubit, kernel
from bloqade.squin import gate, noise, qubit, kernel, qalloc


def load_circuit(
Expand Down Expand Up @@ -92,7 +92,7 @@ def load_circuit(
@squin.kernel
def main():
qreg = get_entangled_qubits()
qreg2 = squin.qubit.new(1)
qreg2 = squin.qalloc(1)
entangle_qubits([qreg[1], qreg2[0]])
return squin.qubit.measure(qreg2)
```
Expand Down Expand Up @@ -254,7 +254,9 @@ def run(
# NOTE: create a new register of appropriate size
n_qubits = len(self.qreg_index)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
self.qreg = frame.push(
func.Invoke((n.result,), callee=qalloc, kwargs=())
).result

self.visit(state, stmt)

Expand Down
15 changes: 6 additions & 9 deletions src/bloqade/pyqrack/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
@qubit.dialect.register(key="pyqrack")
class PyQrackMethods(interp.MethodTable):
@interp.impl(qubit.New)
def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
n_qubits: int = frame.get(stmt.n_qubits)
qreg = ilist.IList(
[
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
for i in interp.memory.allocate(n_qubits=n_qubits)
]
)
return (qreg,)
def new_qubit(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New
):
(addr,) = interp.memory.allocate(1)
qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
return (qb,)

@interp.impl(qubit.Apply)
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
qubit as qubit,
analysis as analysis,
lowering as lowering,
_typeinfer as _typeinfer,
)
from .groups import wired as wired, kernel as kernel
from .stdlib.qubit import qalloc as qalloc
from .stdlib.simple import (
h as h,
s as s,
Expand Down
20 changes: 0 additions & 20 deletions src/bloqade/squin/_typeinfer.py

This file was deleted.

9 changes: 3 additions & 6 deletions src/bloqade/squin/analysis/address_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from bloqade.analysis.address.lattice import (
Address,
AddressReg,
AddressWire,
AddressQubit,
)
Expand Down Expand Up @@ -57,15 +56,13 @@ def apply(
@qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):

# This can be treated like a QRegNew impl
@interp.impl(qubit.New)
def new(
def new_qubit(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: qubit.New,
):
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
interp_.next_address += n_qubits
addr = AddressQubit(interp_.next_address)
interp_.next_address += 1
return (addr,)
3 changes: 2 additions & 1 deletion src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
py_mult_to_mult_pass(method)

if typeinfer:
typeinfer_pass(method)
typeinfer_pass(method) # infer types before desugaring
desugar_pass.rewrite(method.code)

ilist_desugar_pass(method)

if typeinfer:
typeinfer_pass(method) # fix types after desugaring
method.verify_type()
# method.print()

return run_pass

Expand Down
16 changes: 6 additions & 10 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
@statement(dialect=dialect)
class New(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
n_qubits: ir.SSAValue = info.argument(types.Int)
result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any])
result: ir.ResultValue = info.result(QubitType)


@statement(dialect=dialect)
Expand Down Expand Up @@ -94,14 +93,11 @@ class MeasurementId(ir.Statement):

# NOTE: no dependent types in Python, so we have to mark it Any...
@wraps(New)
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
"""Create a new list of qubits.

Args:
n_qubits(int): The number of qubits to create.
def new() -> Qubit:
"""Create a new qubit.

Returns:
(ilist.IList[Qubit, n_qubits]) A list of qubits.
Qubit: A new qubit.
"""
...

Expand Down Expand Up @@ -164,8 +160,8 @@ def broadcast(operator: Op, *qubits: ilist.IList[Qubit, OpSize] | list[Qubit]) -

@squin.kernel
def ghz():
controls = squin.qubit.new(4)
targets = squin.qubit.new(4)
controls = squin.qalloc(4)
targets = squin.qalloc(4)

h = squin.op.h()
squin.qubit.broadcast(h, controls)
Expand Down
25 changes: 25 additions & 0 deletions src/bloqade/squin/stdlib/qubit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any

from kirin.dialects import ilist

from .. import qubit, kernel


@kernel(typeinfer=True)
def qalloc(n_qubits: int) -> ilist.IList[qubit.Qubit, Any]:
"""Allocate a new list of qubits.

Args:
n_qubits(int): The number of qubits to create.

Returns:
(ilist.IList[Qubit, n_qubits]) A list of qubits.
"""

def _new(qid: int) -> qubit.Qubit:
return qubit.new()

return ilist.map(_new, ilist.range(n_qubits))


qalloc.print()
37 changes: 30 additions & 7 deletions test/analysis/address/test_qubit_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from util import collect_address_types

from bloqade.squin import op, qubit, kernel
from bloqade.squin import op, qubit, kernel, qalloc
from bloqade.analysis import address

# test tuple and indexing
Expand All @@ -11,8 +11,8 @@ def test_tuple_address():

@kernel
def test():
q1 = qubit.new(5)
q2 = qubit.new(10)
q1 = qalloc(5)
q2 = qalloc(10)
qubit.broadcast(op.y(), q1)
qubit.apply(op.x(), q2[2]) # desugar creates a new ilist here
# natural to expect two AddressTuple types
Expand All @@ -37,7 +37,7 @@ def test_get_item():

@kernel
def test():
q = qubit.new(5)
q = qalloc(5)
qubit.broadcast(op.y(), q)
x = (q[0], q[3]) # -> AddressTuple(AddressQubit, AddressQubit)
y = q[2] # getitem on ilist # -> AddressQubit
Expand Down Expand Up @@ -66,7 +66,7 @@ def extract_qubits(qubits):

@kernel
def test():
q = qubit.new(5)
q = qalloc(5)
qubit.broadcast(op.y(), q)
return extract_qubits(q)

Expand All @@ -84,7 +84,7 @@ def test_slice():

@kernel
def main():
q = qubit.new(4)
q = qalloc(4)
# get the middle qubits out and apply to them
sub_q = q[1:3]
qubit.broadcast(op.x(), sub_q)
Expand Down Expand Up @@ -117,7 +117,7 @@ def main():
def test_for_loop_idx():
@kernel
def main():
q = qubit.new(3)
q = qalloc(3)
x = op.x()
for i in range(3):
qubit.apply(x, [q[i]])
Expand All @@ -126,3 +126,26 @@ def main():

address_analysis = address.AddressAnalysis(main.dialects)
address_analysis.run_analysis(main, no_raise=False)


def test_new_qubit():
@kernel
def main():
return qalloc()

address_analysis = address.AddressAnalysis(main.dialects)
_, result = address_analysis.run_analysis(main, no_raise=False)
assert result == address.AddressQubit(0)


@pytest.mark.xfail # fails due to ilist.map not being handled correctly
def test_new_stdlib():
@kernel
def main():
return qalloc(10)

address_analysis = address.AddressAnalysis(main.dialects)
_, result = address_analysis.run_analysis(main, no_raise=False)
assert (
result == address.AnyAddress()
) # TODO: should be AddressTuple with AddressQubits
Loading
Loading