Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/qasm2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
emit as emit,
glob as glob,
parse as parse,
analysis as analysis,
dialects as dialects,
parallel as parallel,
)
Expand Down
1 change: 1 addition & 0 deletions src/bloqade/qasm2/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .validation.analysis import QASM2Validation as QASM2Validation
Empty file.
121 changes: 121 additions & 0 deletions src/bloqade/qasm2/analysis/validation/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Any

from kirin import ir, interp
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.dialects import scf
from kirin.validation import ValidationPass
from kirin.analysis.forward import ForwardFrame

from bloqade.qasm2.passes.unroll_if import DontLiftType


class _QASM2ValidationAnalysis(Forward[EmptyLattice]):
keys = ["qasm2.main.validation"]

lattice = EmptyLattice

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()

def eval_fallback(
self, frame: ForwardFrame[EmptyLattice], node: ir.Statement
) -> tuple[EmptyLattice, ...]:
return tuple(self.lattice.bottom() for _ in range(len(node.results)))


@scf.dialect.register(key="qasm2.main.validation")
class __ScfMethods(interp.MethodTable):

@interp.impl(scf.IfElse)
def if_else(
self,
interp_: _QASM2ValidationAnalysis,
frame: ForwardFrame[EmptyLattice],
stmt: scf.IfElse,
):

# TODO: stmt.condition has to be based off a measurement

if len(stmt.then_body.blocks) > 1:
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Only single block is allowed in the then-body of an if-else statement!",
),
)

then_stmts = list(stmt.then_body.stmts())
if len(then_stmts) > 2:
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Only single statements are allowed inside the then-body of an if-else statement!",
),
)

if not isinstance(then_stmts[0], DontLiftType):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt, f"Statement {then_stmts[0]} not allowed inside if clause!"
),
)

self.__validate_empty_yield(interp_, then_stmts[-1])

if len(stmt.else_body.blocks) > 1:
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Only single block is allowed in the else-body of an if-else statement!",
),
)

else_stmts = list(stmt.else_body.stmts())
if len(else_stmts) > 1:
interp_.add_validation_error(
stmt,
ir.ValidationError(stmt, "Non-empty else is not allowed in QASM2!"),
)

self.__validate_empty_yield(interp_, else_stmts[-1])

def __validate_empty_yield(
self, interp_: _QASM2ValidationAnalysis, stmt: ir.Statement
):
if not isinstance(stmt, scf.Yield):
interp_.add_validation_error(
stmt,
ir.ValidationError(
stmt, f"Expected scf.Yield terminator in if clause, got {stmt}"
),
)
elif len(stmt.values) > 0:
interp_.add_validation_error(
stmt, ir.ValidationError(stmt, "Cannot yield values from if statement!")
)

@interp.impl(scf.For)
def for_loop(
self,
interp_: _QASM2ValidationAnalysis,
frame: ForwardFrame[EmptyLattice],
stmt: scf.For,
):
interp_.add_validation_error(
stmt, ir.ValidationError(stmt, "Loops not supported in QASM2!")
)


class QASM2Validation(ValidationPass):
def name(self) -> str:
return "QASM2 validation"

def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
analysis = _QASM2ValidationAnalysis(method.dialects)
frame, _ = analysis.run(method)
return frame, analysis.get_validation_errors()
4 changes: 4 additions & 0 deletions src/bloqade/qasm2/emit/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rich.console import Console
from kirin.analysis import CallGraph
from kirin.dialects import ilist
from kirin.validation import ValidationSuite

from bloqade.qasm2.parse import ast, pprint
from bloqade.qasm2.passes.fold import QASM2Fold
Expand All @@ -14,6 +15,7 @@
from . import impls as impls # register the tables
from .gate import EmitQASM2Gate
from .main import EmitQASM2Main
from ..analysis import QASM2Validation


class QASM2:
Expand Down Expand Up @@ -114,6 +116,8 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
# rewrite parallel to uop
ParallelToUOp(dialects=entry.dialects)(entry)

ValidationSuite([QASM2Validation]).validate(entry).raise_if_invalid()

Py2QASM(entry.dialects)(entry)
target_main = EmitQASM2Main(self.main_target).initialize()
target_main.run(entry)
Expand Down
8 changes: 8 additions & 0 deletions src/bloqade/qasm2/groups.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from kirin import ir, passes
from kirin.prelude import structural_no_opt
from kirin.dialects import scf, func, ilist, ssacfg, lowering
from kirin.validation import ValidationSuite

from bloqade.qasm2.passes import UnrollIfs
from bloqade.qasm2.analysis import QASM2Validation
from bloqade.qasm2.dialects import (
uop,
core,
Expand Down Expand Up @@ -64,6 +67,7 @@ def run_pass(
def main(self):
fold_pass = passes.Fold(self)
typeinfer_pass = passes.TypeInfer(self)
unroll_ifs = UnrollIfs(self)

def run_pass(
method: ir.Method,
Expand All @@ -78,6 +82,10 @@ def run_pass(

typeinfer_pass(method)
method.verify_type()
unroll_ifs(method)

validation_result = ValidationSuite([QASM2Validation]).validate(method)
validation_result.raise_if_invalid()

return run_pass

Expand Down
17 changes: 12 additions & 5 deletions test/qasm2/emit/test_qasm2_emit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import pytest
from kirin.interp import InterpreterError

from bloqade import qasm2


Expand Down Expand Up @@ -240,8 +237,13 @@ def non_empty_else():

target = qasm2.emit.QASM2()

with pytest.raises(InterpreterError):
had_error = False
try:
target.emit(non_empty_else)
except Exception:
# TODO: this is just to work around ExceptionGroup for now
had_error = True
assert had_error

@qasm2.extended
def multiline_then():
Expand All @@ -256,8 +258,13 @@ def multiline_then():
return q

target = qasm2.emit.QASM2(unroll_ifs=False)
with pytest.raises(InterpreterError):
had_error = False
try:
target.emit(multiline_then)
except Exception:
# TODO: this is just to work around ExceptionGroup for now
had_error = True
assert had_error

@qasm2.extended
def valid_if():
Expand Down
44 changes: 22 additions & 22 deletions test/qasm2/test_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from bloqade import qasm2
from bloqade.analysis.address import (
Unknown,
AddressReg,
ConstResult,
AddressQubit,
Expand Down Expand Up @@ -51,27 +50,28 @@ def tuple_count():
assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7)


def test_dynamic_address():
@qasm2.main
def dynamic_address():
ra = qasm2.qreg(3)
rb = qasm2.qreg(4)
ca = qasm2.creg(2)
qasm2.measure(ra[0], ca[0])
qasm2.measure(rb[1], ca[1])
if ca[0] == ca[1]:
ret = ra
else:
ret = rb

return ret

# dynamic_address.code.print()
dynamic_address.print()
fold(dynamic_address)
frame, result = address.run(dynamic_address)
dynamic_address.print(analysis=frame.entries)
assert isinstance(result, Unknown)
# NOTE: this is also invalid for QASM2 - you can't yield from if statements and no else bodies
# def test_dynamic_address():
# @qasm2.main
# def dynamic_address():
# ra = qasm2.qreg(3)
# rb = qasm2.qreg(4)
# ca = qasm2.creg(2)
# qasm2.measure(ra[0], ca[0])
# qasm2.measure(rb[1], ca[1])
# if ca[0] == ca[1]:
# ret = ra
# else:
# ret = rb

# return ret

# # dynamic_address.code.print()
# dynamic_address.print()
# fold(dynamic_address)
# frame, result = address.run(dynamic_address)
# dynamic_address.print(analysis=frame.entries)
# assert isinstance(result, Unknown)


# NOTE: this is invalid for QASM2
Expand Down
Loading