Skip to content

Commit a2fa304

Browse files
david-plweinbe58
andauthored
Rework squin's types for rotations, controls and binary ops (#454)
Draft for the extended type system suggested in #403. FYI, @weinbe58. Using generics for the composite types with shorthand aliases works fine and skips the need to implement any custom type inference. --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent f67fd44 commit a2fa304

File tree

3 files changed

+227
-13
lines changed

3 files changed

+227
-13
lines changed

src/bloqade/squin/op/stmts.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
from kirin import ir, types, lowering
22
from kirin.decl import info, statement
33

4-
from .types import OpType, PauliOpType, PauliStringType
4+
from .types import (
5+
OpType,
6+
ROpType,
7+
XOpType,
8+
YOpType,
9+
ZOpType,
10+
KronType,
11+
MultType,
12+
PauliOpType,
13+
ControlOpType,
14+
PauliStringType,
15+
ControlledOpType,
16+
)
517
from .number import NumberType
618
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
719
from ._dialect import dialect
@@ -22,22 +34,28 @@ class CompositeOp(Operator):
2234
pass
2335

2436

37+
LhsType = types.TypeVar("Lhs", bound=OpType)
38+
RhsType = types.TypeVar("Rhs", bound=OpType)
39+
40+
2541
@statement
2642
class BinaryOp(CompositeOp):
27-
lhs: ir.SSAValue = info.argument(OpType)
28-
rhs: ir.SSAValue = info.argument(OpType)
43+
lhs: ir.SSAValue = info.argument(LhsType)
44+
rhs: ir.SSAValue = info.argument(RhsType)
2945

3046

3147
@statement(dialect=dialect)
3248
class Kron(BinaryOp):
3349
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
3450
is_unitary: bool = info.attribute(default=False)
51+
result: ir.ResultValue = info.result(KronType[LhsType, RhsType])
3552

3653

3754
@statement(dialect=dialect)
3855
class Mult(BinaryOp):
3956
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
4057
is_unitary: bool = info.attribute(default=False)
58+
result: ir.ResultValue = info.result(MultType[LhsType, RhsType])
4159

4260

4361
@statement(dialect=dialect)
@@ -59,15 +77,20 @@ class Scale(CompositeOp):
5977
class Control(CompositeOp):
6078
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
6179
is_unitary: bool = info.attribute(default=False)
62-
op: ir.SSAValue = info.argument(OpType)
80+
op: ir.SSAValue = info.argument(ControlledOpType)
6381
n_controls: int = info.attribute()
82+
result: ir.ResultValue = info.result(ControlOpType[ControlledOpType])
83+
84+
85+
RotationAxisType = types.TypeVar("RotationAxis", bound=OpType)
6486

6587

6688
@statement(dialect=dialect)
6789
class Rot(CompositeOp):
6890
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()})
69-
axis: ir.SSAValue = info.argument(OpType)
91+
axis: ir.SSAValue = info.argument(RotationAxisType)
7092
angle: ir.SSAValue = info.argument(types.Float)
93+
result: ir.ResultValue = info.result(ROpType[RotationAxisType])
7194

7295

7396
@statement(dialect=dialect)
@@ -184,17 +207,17 @@ def verify(self) -> None:
184207

185208
@statement(dialect=dialect)
186209
class X(PauliOp):
187-
pass
210+
result: ir.ResultValue = info.result(XOpType)
188211

189212

190213
@statement(dialect=dialect)
191214
class Y(PauliOp):
192-
pass
215+
result: ir.ResultValue = info.result(YOpType)
193216

194217

195218
@statement(dialect=dialect)
196219
class Z(PauliOp):
197-
pass
220+
result: ir.ResultValue = info.result(ZOpType)
198221

199222

200223
@statement(dialect=dialect)

src/bloqade/squin/op/types.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import overload
1+
from typing import Generic, TypeVar, overload
22

33
from kirin import types
44

@@ -21,21 +21,108 @@ def __rmul__(self, other: complex) -> "Op":
2121
raise NotImplementedError("@ can only be used within a squin kernel program")
2222

2323

24+
OpType = types.PyClass(Op)
25+
26+
27+
class CompositeOp(Op):
28+
pass
29+
30+
31+
CompositeOpType = types.PyClass(CompositeOp)
32+
33+
LhsType = TypeVar("LhsType", bound=Op)
34+
RhsType = TypeVar("RhsType", bound=Op)
35+
36+
37+
class BinaryOp(Op, Generic[LhsType, RhsType]):
38+
lhs: LhsType
39+
rhs: RhsType
40+
41+
42+
BinaryOpType = types.Generic(BinaryOp, OpType, OpType)
43+
44+
45+
class Mult(BinaryOp[LhsType, RhsType]):
46+
pass
47+
48+
49+
MultType = types.Generic(Mult, OpType, OpType)
50+
51+
52+
class Kron(BinaryOp[LhsType, RhsType]):
53+
pass
54+
55+
56+
KronType = types.Generic(Kron, OpType, OpType)
57+
58+
2459
class MultiQubitPauliOp(Op):
2560
pass
2661

2762

63+
MultiQubitPauliOpType = types.PyClass(MultiQubitPauliOp)
64+
65+
2866
class PauliStringOp(MultiQubitPauliOp):
2967
pass
3068

3169

70+
PauliStringType = types.PyClass(PauliStringOp)
71+
72+
3273
class PauliOp(MultiQubitPauliOp):
3374
pass
3475

3576

36-
OpType = types.PyClass(Op)
37-
MultiQubitPauliOpType = types.PyClass(MultiQubitPauliOp)
38-
PauliStringType = types.PyClass(PauliStringOp)
3977
PauliOpType = types.PyClass(PauliOp)
4078

41-
NumOperators = types.TypeVar("NumOperators")
79+
80+
class XOp(PauliOp):
81+
pass
82+
83+
84+
XOpType = types.PyClass(XOp)
85+
86+
87+
class YOp(PauliOp):
88+
pass
89+
90+
91+
YOpType = types.PyClass(YOp)
92+
93+
94+
class ZOp(PauliOp):
95+
pass
96+
97+
98+
ZOpType = types.PyClass(ZOp)
99+
100+
101+
ControlledOp = TypeVar("ControlledOp", bound=Op)
102+
103+
104+
class ControlOp(CompositeOp, Generic[ControlledOp]):
105+
op: ControlledOp
106+
107+
108+
ControlledOpType = types.TypeVar("ControlledOp", bound=OpType)
109+
ControlOpType = types.Generic(ControlOp, ControlledOpType)
110+
CXOpType = ControlOpType[XOpType]
111+
CYOpType = ControlOpType[YOpType]
112+
CZOpType = ControlOpType[ZOpType]
113+
114+
RotationAxis = TypeVar("RotationAxis", bound=Op)
115+
116+
117+
class ROp(CompositeOp, Generic[RotationAxis]):
118+
axis: RotationAxis
119+
angle: float
120+
121+
122+
ROpType = types.Generic(ROp, OpType)
123+
RxOpType = ROpType[XOpType]
124+
RyOpType = ROpType[YOpType]
125+
RzOpType = ROpType[ZOpType]
126+
127+
128+
NumOperators = types.TypeVar("NumOperators", bound=types.Int)

test/squin/test_typeinfer.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin import ir
22
from kirin.types import Any, Literal
3+
from kirin.dialects import func
34
from kirin.dialects.ilist import IListType
45
from kirin.analysis.typeinfer import TypeInference
56

@@ -69,3 +70,106 @@ def test():
6970

7071
assert [frame.entries[result] for result in results_at(test, 0, 3)] == [QubitType]
7172
assert [frame.entries[result] for result in results_at(test, 0, 5)] == [QubitType]
73+
74+
75+
def test_generic_rot():
76+
@squin.kernel(fold=False)
77+
def main():
78+
z = squin.op.z()
79+
squin.op.rot(axis=z, angle=0.123)
80+
81+
main.print()
82+
83+
for stmt in main.callable_region.blocks[0].stmts:
84+
if isinstance(stmt, squin.op.stmts.Rot):
85+
assert stmt.result.type.is_subseteq(squin.op.types.RzOpType)
86+
assert stmt.result.type.is_subseteq(squin.op.types.CompositeOpType)
87+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
88+
89+
90+
def test_generic_control():
91+
@squin.kernel(fold=False)
92+
def main():
93+
z = squin.op.z()
94+
squin.op.control(z, n_controls=1)
95+
squin.op.cz()
96+
97+
main.print()
98+
99+
for stmt in main.callable_region.blocks[0].stmts:
100+
if isinstance(stmt, (squin.op.stmts.Control, func.Invoke)):
101+
assert stmt.result.type.is_subseteq(squin.op.types.CZOpType)
102+
assert stmt.result.type.is_subseteq(squin.op.types.CompositeOpType)
103+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
104+
105+
106+
def test_mult():
107+
108+
@squin.kernel(fold=False)
109+
def main():
110+
rx = squin.op.rx(1.0)
111+
return rx * rx
112+
113+
main.print()
114+
115+
for stmt in main.callable_region.blocks[0].stmts:
116+
if isinstance(stmt, squin.op.stmts.Mult):
117+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
118+
assert stmt.result.type.is_subseteq(squin.op.types.BinaryOpType)
119+
assert stmt.result.type.is_subseteq(
120+
squin.op.types.MultType[
121+
squin.op.types.RxOpType, squin.op.types.RxOpType
122+
]
123+
)
124+
125+
@squin.kernel(fold=False)
126+
def main2():
127+
rx = squin.op.rx(1.0)
128+
rz = squin.op.rz(1.123)
129+
return rx * rz
130+
131+
for stmt in main2.callable_region.blocks[0].stmts:
132+
if isinstance(stmt, squin.op.stmts.Mult):
133+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
134+
assert stmt.result.type.is_subseteq(squin.op.types.BinaryOpType)
135+
assert stmt.result.type.is_subseteq(
136+
squin.op.types.MultType[
137+
squin.op.types.RxOpType, squin.op.types.RzOpType
138+
]
139+
)
140+
141+
142+
def test_kron():
143+
144+
@squin.kernel(fold=False)
145+
def main():
146+
rx = squin.op.rx(1.0)
147+
return squin.op.kron(rx, rx)
148+
149+
main.print()
150+
151+
for stmt in main.callable_region.blocks[0].stmts:
152+
if isinstance(stmt, squin.op.stmts.Mult):
153+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
154+
assert stmt.result.type.is_subseteq(squin.op.types.BinaryOpType)
155+
assert stmt.result.type.is_subseteq(
156+
squin.op.types.KronType[
157+
squin.op.types.RxOpType, squin.op.types.RxOpType
158+
]
159+
)
160+
161+
@squin.kernel(fold=False)
162+
def main2():
163+
rx = squin.op.rx(1.0)
164+
rz = squin.op.rz(1.123)
165+
return squin.op.kron(rx, rz)
166+
167+
for stmt in main2.callable_region.blocks[0].stmts:
168+
if isinstance(stmt, squin.op.stmts.Mult):
169+
assert stmt.result.type.is_subseteq(squin.op.types.OpType)
170+
assert stmt.result.type.is_subseteq(squin.op.types.BinaryOpType)
171+
assert stmt.result.type.is_subseteq(
172+
squin.op.types.KronType[
173+
squin.op.types.RxOpType, squin.op.types.RzOpType
174+
]
175+
)

0 commit comments

Comments
 (0)