-
Notifications
You must be signed in to change notification settings - Fork 1
Simplyfing qubit.new
#518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplyfing qubit.new
#518
Changes from 23 commits
4f4088a
acfa798
fe677ea
738bfe8
8396e77
d9eda24
bee7cf6
d90907f
9257a32
9bc330f
a51fc40
0807726
a3b7124
c8ec9a8
78e36ed
0edcf0f
c28ff2c
0944d60
ac72611
28d1782
a2564d5
126b0e5
aa20f7d
0e16e43
6829892
a716da4
ba8acdc
7c53555
15366cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,11 @@ | |
from kirin import ir, types, interp | ||
from kirin.emit import EmitABC, EmitError, EmitFrame | ||
from kirin.interp import MethodTable, impl | ||
from kirin.passes import inline | ||
from kirin.dialects import func | ||
from kirin.dialects import py, func | ||
from typing_extensions import Self | ||
|
||
from bloqade.squin import kernel | ||
from bloqade.rewrite.passes import AggressiveUnroll | ||
|
||
|
||
def emit_circuit( | ||
|
@@ -28,7 +28,7 @@ def emit_circuit( | |
Keyword Args: | ||
circuit_qubits (Sequence[cirq.Qid] | None): | ||
A list of qubits to use as the qubits in the circuit. Defaults to None. | ||
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new` | ||
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc` | ||
statement in the order they appear inside the kernel. | ||
**Note**: If a list of qubits is provided, make sure that there is a sufficient | ||
number of qubits for the resulting circuit. | ||
|
@@ -48,7 +48,7 @@ def emit_circuit( | |
|
||
@squin.kernel | ||
def main(): | ||
q = squin.qubit.new(2) | ||
q = squin.qalloc(2) | ||
squin.h(q[0]) | ||
squin.cx(q[0], q[1]) | ||
|
||
|
@@ -74,8 +74,10 @@ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]): | |
|
||
@squin.kernel | ||
def main(): | ||
q = squin.qubit.new(2) | ||
entangle(q) | ||
q = squin.qalloc(2) | ||
q2 = squin.qalloc(3) | ||
squin.cx(q[1], q2[2]) | ||
|
||
|
||
# custom list of qubits on grid | ||
qubits = [cirq.GridQubit(i, i+1) for i in range(5)] | ||
|
@@ -112,10 +114,44 @@ def main(): | |
|
||
emitter = EmitCirq(qubits=circuit_qubits) | ||
|
||
mt_ = mt.similar(mt.dialects) | ||
inline.InlinePass(mt_.dialects).fixpoint(mt_) | ||
symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface) | ||
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None: | ||
raise EmitError("The method is not a symbol, cannot emit circuit!") | ||
|
||
sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap() | ||
|
||
if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None: | ||
raise EmitError( | ||
f"The method {sym_name} does not have a signature, cannot emit circuit!" | ||
) | ||
|
||
signature = signature_trait.get_signature(mt.code) | ||
new_signature = func.Signature(inputs=(), output=signature.output) | ||
|
||
callable_region = mt.callable_region.clone() | ||
entry_block = callable_region.blocks[0] | ||
args_ssa = list(entry_block.args) | ||
first_stmt = entry_block.first_stmt | ||
|
||
assert first_stmt is not None, "Method has no statements!" | ||
if len(args_ssa) - 1 != len(args): | ||
raise EmitError( | ||
f"The method {sym_name} takes {len(args_ssa)} arguments, but you passed in {len(args)} via the `args` keyword!" | ||
) | ||
|
||
for arg, arg_ssa in zip(args, args_ssa[1:], strict=True): | ||
(value := py.Constant(arg)).insert_before(first_stmt) | ||
arg_ssa.replace_by(value.result) | ||
entry_block.args.delete(arg_ssa) | ||
|
||
new_func = func.Function( | ||
sym_name=sym_name, body=callable_region, signature=new_signature | ||
) | ||
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I admit to being a bit lost here. Why is it necessary to construct a new method like this here? Why is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The signature is different, I removed all the arguments of the function. |
||
|
||
return emitter.run(mt_, args=args) | ||
AggressiveUnroll(mt_.dialects).fixpoint(mt_) | ||
mt_.print(hint="const") | ||
return emitter.run(mt_, args=()) | ||
|
||
|
||
@dataclass | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
|
||
from typing import Any, overload | ||
|
||
from kirin import ir, types, lowering | ||
from kirin import ir, types, interp, lowering | ||
from kirin.decl import info, statement | ||
from kirin.dialects import ilist | ||
from kirin.lowering import wraps | ||
|
@@ -22,8 +22,7 @@ | |
@statement(dialect=dialect) | ||
class New(ir.Statement): | ||
traits = frozenset({lowering.FromPythonCall()}) | ||
n_qubits: ir.SSAValue = info.argument(types.Int) | ||
result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) | ||
result: ir.ResultValue = info.result(QubitType) | ||
|
||
|
||
@statement(dialect=dialect) | ||
|
@@ -44,13 +43,16 @@ class MeasureQubit(ir.Statement): | |
result: ir.ResultValue = info.result(MeasurementResultType) | ||
|
||
|
||
Len = types.TypeVar("Len") | ||
|
||
|
||
@statement(dialect=dialect) | ||
class MeasureQubitList(ir.Statement): | ||
name = "measure.qubit.list" | ||
|
||
traits = frozenset({lowering.FromPythonCall()}) | ||
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) | ||
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType]) | ||
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) | ||
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len]) | ||
|
||
|
||
@statement(dialect=dialect) | ||
|
@@ -75,14 +77,11 @@ class Reset(ir.Statement): | |
|
||
# NOTE: no dependent types in Python, so we have to mark it Any... | ||
@wraps(New) | ||
def new(n_qubits: int) -> ilist.IList[Qubit, Any]: | ||
"""Create a new list of qubits. | ||
|
||
Args: | ||
n_qubits(int): The number of qubits to create. | ||
def new() -> Qubit: | ||
"""Create a new qubit. | ||
|
||
Returns: | ||
(ilist.IList[Qubit, n_qubits]) A list of qubits. | ||
Qubit: A new qubit. | ||
""" | ||
... | ||
|
||
|
@@ -117,3 +116,20 @@ def get_qubit_id(qubit: Qubit) -> int: ... | |
|
||
@wraps(MeasurementId) | ||
def get_measurement_id(measurement: MeasurementResult) -> int: ... | ||
|
||
|
||
# TODO: investigate why this is needed to get type inference to be correct. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be worth considering this as part of #549, since depending on the changes there, cc @johnzl-777 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR is a bit old now isn't it? I'm all for getting this in first and then I can make the necessary tweaks in a dedicated PR for #549 . The request for changes might take a bit considering Phillip's in Munich for the Munich Quantum Software Forum There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment here is more on the Kirin side of things. In principle the statements typing should just work as it but for some reason it wasn't so I added an explicit type inference method |
||
@dialect.register(key="typeinfer") | ||
class __TypeInfer(interp.MethodTable): | ||
@interp.impl(MeasureQubitList) | ||
def measure_list( | ||
self, _interp, frame: interp.AbstractFrame, stmt: MeasureQubitList | ||
): | ||
qubit_type = frame.get(stmt.qubits) | ||
|
||
if isinstance(qubit_type, types.Generic): | ||
len_type = qubit_type.vars[1] | ||
else: | ||
len_type = types.Any | ||
|
||
return (ilist.IListType[MeasurementResultType, len_type],) |
Uh oh!
There was an error while loading. Please reload this page.