Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4f4088a
adding new qubit and stdlib to allocate new qubits
weinbe58 Sep 30, 2025
acfa798
Adding pyqrack interpreter and using that in test
weinbe58 Sep 30, 2025
fe677ea
Adding test for new stdlib function
weinbe58 Sep 30, 2025
738bfe8
marking test as expected to fail
weinbe58 Oct 1, 2025
8396e77
fixing incorrect usage xfail
weinbe58 Oct 1, 2025
d9eda24
removing print
weinbe58 Oct 6, 2025
bee7cf6
Merge branch 'main' into phil/508-simplifying-qubitnew
weinbe58 Oct 7, 2025
d90907f
merging stashed changes
weinbe58 Oct 7, 2025
9257a32
WIP fixing tests
weinbe58 Oct 7, 2025
9bc330f
merging main
weinbe58 Oct 8, 2025
a51fc40
Fixing bug in test
weinbe58 Oct 8, 2025
0807726
fixing some tests
weinbe58 Oct 8, 2025
a3b7124
merging main
weinbe58 Oct 8, 2025
c8ec9a8
merging main
weinbe58 Oct 16, 2025
78e36ed
Fixing some tests
weinbe58 Oct 16, 2025
0edcf0f
removing print
weinbe58 Oct 16, 2025
c28ff2c
WIP: trying to fix cirq emit
weinbe58 Oct 16, 2025
0944d60
fixing test adding fixedpoint to unroll
weinbe58 Oct 16, 2025
ac72611
pin lower bound on kirin
weinbe58 Oct 16, 2025
28d1782
fixing test
weinbe58 Oct 16, 2025
a2564d5
fixing tests
weinbe58 Oct 17, 2025
126b0e5
fixing potential issue with unroll pass
weinbe58 Oct 17, 2025
aa20f7d
fixing last test
weinbe58 Oct 17, 2025
0e16e43
Update test/pyqrack/runtime/test_qrack.py
weinbe58 Oct 17, 2025
6829892
Update src/bloqade/cirq_utils/emit/base.py
weinbe58 Oct 17, 2025
a716da4
removing print
weinbe58 Oct 17, 2025
ba8acdc
Merge branch 'phil/508-simplifying-qubitnew' of https://github.com/Qu…
weinbe58 Oct 17, 2025
7c53555
removing print
weinbe58 Oct 17, 2025
15366cb
Update groups.py
weinbe58 Oct 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.17.26",
"kirin-toolchain~=0.17.30",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down
54 changes: 45 additions & 9 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from kirin import ir, types, interp
from kirin.emit import EmitABC, EmitError, EmitFrame
from kirin.interp import MethodTable, impl
from kirin.passes import inline
from kirin.dialects import func
from kirin.dialects import py, func
from typing_extensions import Self

from bloqade.squin import kernel
from bloqade.rewrite.passes import AggressiveUnroll


def emit_circuit(
Expand All @@ -28,7 +28,7 @@ def emit_circuit(
Keyword Args:
circuit_qubits (Sequence[cirq.Qid] | None):
A list of qubits to use as the qubits in the circuit. Defaults to None.
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc`
statement in the order they appear inside the kernel.
**Note**: If a list of qubits is provided, make sure that there is a sufficient
number of qubits for the resulting circuit.
Expand All @@ -48,7 +48,7 @@ def emit_circuit(

@squin.kernel
def main():
q = squin.qubit.new(2)
q = squin.qalloc(2)
squin.h(q[0])
squin.cx(q[0], q[1])

Expand All @@ -74,8 +74,10 @@ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):

@squin.kernel
def main():
q = squin.qubit.new(2)
entangle(q)
q = squin.qalloc(2)
q2 = squin.qalloc(3)
squin.cx(q[1], q2[2])


# custom list of qubits on grid
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
Expand Down Expand Up @@ -112,10 +114,44 @@ def main():

emitter = EmitCirq(qubits=circuit_qubits)

mt_ = mt.similar(mt.dialects)
inline.InlinePass(mt_.dialects).fixpoint(mt_)
symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
raise EmitError("The method is not a symbol, cannot emit circuit!")

sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()

if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
raise EmitError(
f"The method {sym_name} does not have a signature, cannot emit circuit!"
)

signature = signature_trait.get_signature(mt.code)
new_signature = func.Signature(inputs=(), output=signature.output)

callable_region = mt.callable_region.clone()
entry_block = callable_region.blocks[0]
args_ssa = list(entry_block.args)
first_stmt = entry_block.first_stmt

assert first_stmt is not None, "Method has no statements!"
if len(args_ssa) - 1 != len(args):
raise EmitError(
f"The method {sym_name} takes {len(args_ssa)} arguments, but you passed in {len(args)} via the `args` keyword!"
)

for arg, arg_ssa in zip(args, args_ssa[1:], strict=True):
(value := py.Constant(arg)).insert_before(first_stmt)
arg_ssa.replace_by(value.result)
entry_block.args.delete(arg_ssa)

new_func = func.Function(
sym_name=sym_name, body=callable_region, signature=new_signature
)
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I admit to being a bit lost here. Why is it necessary to construct a new method like this here? Why is mt_ = mt.similar(mt.dialects) not sufficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature is different, I removed all the arguments of the function.


return emitter.run(mt_, args=args)
AggressiveUnroll(mt_.dialects).fixpoint(mt_)
mt_.print(hint="const")
return emitter.run(mt_, args=())


@dataclass
Expand Down
14 changes: 4 additions & 10 deletions src/bloqade/cirq_utils/emit/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,13 @@
class EmitCirqQubitMethods(MethodTable):
@impl(qubit.New)
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
n_qubits = frame.get(stmt.n_qubits)

if frame.qubits is not None:
cirq_qubits = tuple(
frame.qubits[i + frame.qubit_index] for i in range(n_qubits)
)
cirq_qubit = frame.qubits[frame.qubit_index]
else:
cirq_qubits = tuple(
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
)
cirq_qubit = cirq.LineQubit(frame.qubit_index)

frame.qubit_index += n_qubits
return (cirq_qubits,)
frame.qubit_index += 1
return (cirq_qubit,)

@impl(qubit.MeasureQubit)
def measure_qubit(
Expand Down
25 changes: 20 additions & 5 deletions src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kirin.rewrite import Walk, CFGCompactify
from kirin.dialects import py, scf, func, ilist

from bloqade.squin import gate, noise, qubit, kernel
from bloqade.squin import gate, noise, qubit, kernel, qalloc


def load_circuit(
Expand Down Expand Up @@ -92,7 +92,7 @@ def load_circuit(
@squin.kernel
def main():
qreg = get_entangled_qubits()
qreg2 = squin.qubit.new(1)
qreg2 = squin.qalloc(1)
entangle_qubits([qreg[1], qreg2[0]])
return squin.qubit.measure(qreg2)
```
Expand Down Expand Up @@ -142,14 +142,19 @@ def main():
body=body,
)

return ir.Method(
mt = ir.Method(
mod=None,
py_func=None,
sym_name=kernel_name,
arg_names=arg_names,
dialects=dialects,
code=code,
)
mt.print()
assert (run_pass := kernel.run_pass) is not None
run_pass(mt, typeinfer=True)

return mt


CirqNode = (
Expand Down Expand Up @@ -254,7 +259,9 @@ def run(
# NOTE: create a new register of appropriate size
n_qubits = len(self.qreg_index)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
self.qreg = frame.push(
func.Invoke((n.result,), callee=qalloc, kwargs=())
).result

self.visit(state, stmt)

Expand Down Expand Up @@ -382,8 +389,16 @@ def bool_op_or(x: bool, y: bool) -> bool:
# NOTE: remove stmt from parent block
then_stmt.detach()
then_body = ir.Block((then_stmt,))
then_body.args.append_from(types.Bool, name="cond")
then_body.stmts.append(scf.Yield())

else_body = ir.Block(())
else_body.args.append_from(types.Bool, name="cond")
else_body.stmts.append(scf.Yield())

return state.current_frame.push(scf.IfElse(condition, then_body=then_body))
return state.current_frame.push(
scf.IfElse(condition, then_body=then_body, else_body=else_body)
)

def visit_MeasurementGate(
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
Expand Down
15 changes: 6 additions & 9 deletions src/bloqade/pyqrack/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@
@qubit.dialect.register(key="pyqrack")
class PyQrackMethods(interp.MethodTable):
@interp.impl(qubit.New)
def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
n_qubits: int = frame.get(stmt.n_qubits)
qreg = ilist.IList(
[
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
for i in interp.memory.allocate(n_qubits=n_qubits)
]
)
return (qreg,)
def new_qubit(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New
):
(addr,) = interp.memory.allocate(1)
qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
return (qb,)

def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
if qbit.is_active():
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/rewrite/passes/aggressive_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
InlineGetField(),
InlineGetItem(),
ilist.rewrite.InlineGetItem(),
ilist.rewrite.FlattenAdd(),
ilist.rewrite.HintLen(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
Expand Down Expand Up @@ -68,7 +69,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
.rewrite(mt.code)
.join(result)
)
result = self.typeinfer.unsafe_run(mt).join(result)
self.typeinfer.unsafe_run(mt)
result = self.fold.unsafe_run(mt).join(result)
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
noise as noise,
qubit as qubit,
analysis as analysis,
_typeinfer as _typeinfer,
)
from .groups import kernel as kernel
from .stdlib.qubit import qalloc as qalloc
from .stdlib.simple import (
h as h,
s as s,
Expand Down
19 changes: 0 additions & 19 deletions src/bloqade/squin/_typeinfer.py

This file was deleted.

10 changes: 4 additions & 6 deletions src/bloqade/squin/analysis/address_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from bloqade.analysis.address.lattice import (
Address,
AddressReg,
AddressQubit,
)
from bloqade.analysis.address.analysis import AddressAnalysis

Expand All @@ -27,15 +27,13 @@
@qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):

# This can be treated like a QRegNew impl
@interp.impl(qubit.New)
def new(
def new_qubit(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: qubit.New,
):
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
interp_.next_address += n_qubits
addr = AddressQubit(interp_.next_address)
interp_.next_address += 1
return (addr,)
3 changes: 2 additions & 1 deletion src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
fold_pass.fixpoint(method)

if typeinfer:
typeinfer_pass(method)
typeinfer_pass(method) # infer types before desugaring
desugar_pass.rewrite(method.code)

ilist_desugar_pass(method)

if typeinfer:
typeinfer_pass(method) # fix types after desugaring
method.verify_type()
# method.print()

return run_pass
38 changes: 27 additions & 11 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from typing import Any, overload

from kirin import ir, types, lowering
from kirin import ir, types, interp, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist
from kirin.lowering import wraps
Expand All @@ -22,8 +22,7 @@
@statement(dialect=dialect)
class New(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
n_qubits: ir.SSAValue = info.argument(types.Int)
result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any])
result: ir.ResultValue = info.result(QubitType)


@statement(dialect=dialect)
Expand All @@ -44,13 +43,16 @@ class MeasureQubit(ir.Statement):
result: ir.ResultValue = info.result(MeasurementResultType)


Len = types.TypeVar("Len")


@statement(dialect=dialect)
class MeasureQubitList(ir.Statement):
name = "measure.qubit.list"

traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType])
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len])


@statement(dialect=dialect)
Expand All @@ -75,14 +77,11 @@ class Reset(ir.Statement):

# NOTE: no dependent types in Python, so we have to mark it Any...
@wraps(New)
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
"""Create a new list of qubits.

Args:
n_qubits(int): The number of qubits to create.
def new() -> Qubit:
"""Create a new qubit.

Returns:
(ilist.IList[Qubit, n_qubits]) A list of qubits.
Qubit: A new qubit.
"""
...

Expand Down Expand Up @@ -117,3 +116,20 @@ def get_qubit_id(qubit: Qubit) -> int: ...

@wraps(MeasurementId)
def get_measurement_id(measurement: MeasurementResult) -> int: ...


# TODO: investigate why this is needed to get type inference to be correct.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth considering this as part of #549, since depending on the changes there, MeasureQubit and MeasureQubitList may be consolidated.

cc @johnzl-777

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is a bit old now isn't it? I'm all for getting this in first and then I can make the necessary tweaks in a dedicated PR for #549 .

The request for changes might take a bit considering Phillip's in Munich for the Munich Quantum Software Forum

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment here is more on the Kirin side of things.

In principle the statements typing should just work as it but for some reason it wasn't so I added an explicit type inference method

@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
@interp.impl(MeasureQubitList)
def measure_list(
self, _interp, frame: interp.AbstractFrame, stmt: MeasureQubitList
):
qubit_type = frame.get(stmt.qubits)

if isinstance(qubit_type, types.Generic):
len_type = qubit_type.vars[1]
else:
len_type = types.Any

return (ilist.IListType[MeasurementResultType, len_type],)
3 changes: 2 additions & 1 deletion src/bloqade/squin/rewrite/wrap_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class WrapAddressAnalysis(WrapAnalysis):
address_analysis: dict[ir.SSAValue, Address]

def wrap(self, value: ir.SSAValue) -> bool:
address_analysis_result = self.address_analysis[value]
if (address_analysis_result := self.address_analysis.get(value)) is None:
return False

if value.hints.get("address") is not None:
return False
Expand Down
Loading