From ef935d0ca10e24efb1d89a999c8aa3d812321503 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 20 Nov 2025 16:06:36 +0100 Subject: [PATCH] Validation for QASM2 main kernels --- src/bloqade/qasm2/__init__.py | 1 + src/bloqade/qasm2/analysis/__init__.py | 1 + .../qasm2/analysis/validation/__init__.py | 0 .../qasm2/analysis/validation/analysis.py | 121 ++++++++++++++++++ src/bloqade/qasm2/emit/target.py | 4 + src/bloqade/qasm2/groups.py | 8 ++ test/qasm2/emit/test_qasm2_emit.py | 17 ++- test/qasm2/test_count.py | 44 +++---- 8 files changed, 169 insertions(+), 27 deletions(-) create mode 100644 src/bloqade/qasm2/analysis/__init__.py create mode 100644 src/bloqade/qasm2/analysis/validation/__init__.py create mode 100644 src/bloqade/qasm2/analysis/validation/analysis.py diff --git a/src/bloqade/qasm2/__init__.py b/src/bloqade/qasm2/__init__.py index a6c615170..0715b75d2 100644 --- a/src/bloqade/qasm2/__init__.py +++ b/src/bloqade/qasm2/__init__.py @@ -4,6 +4,7 @@ emit as emit, glob as glob, parse as parse, + analysis as analysis, dialects as dialects, parallel as parallel, ) diff --git a/src/bloqade/qasm2/analysis/__init__.py b/src/bloqade/qasm2/analysis/__init__.py new file mode 100644 index 000000000..e864ef56d --- /dev/null +++ b/src/bloqade/qasm2/analysis/__init__.py @@ -0,0 +1 @@ +from .validation.analysis import QASM2Validation as QASM2Validation diff --git a/src/bloqade/qasm2/analysis/validation/__init__.py b/src/bloqade/qasm2/analysis/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/bloqade/qasm2/analysis/validation/analysis.py b/src/bloqade/qasm2/analysis/validation/analysis.py new file mode 100644 index 000000000..596413f94 --- /dev/null +++ b/src/bloqade/qasm2/analysis/validation/analysis.py @@ -0,0 +1,121 @@ +from typing import Any + +from kirin import ir, interp +from kirin.lattice import EmptyLattice +from kirin.analysis import Forward +from kirin.dialects import scf +from kirin.validation import ValidationPass +from kirin.analysis.forward import ForwardFrame + +from bloqade.qasm2.passes.unroll_if import DontLiftType + + +class _QASM2ValidationAnalysis(Forward[EmptyLattice]): + keys = ["qasm2.main.validation"] + + lattice = EmptyLattice + + def method_self(self, method: ir.Method) -> EmptyLattice: + return self.lattice.bottom() + + def eval_fallback( + self, frame: ForwardFrame[EmptyLattice], node: ir.Statement + ) -> tuple[EmptyLattice, ...]: + return tuple(self.lattice.bottom() for _ in range(len(node.results))) + + +@scf.dialect.register(key="qasm2.main.validation") +class __ScfMethods(interp.MethodTable): + + @interp.impl(scf.IfElse) + def if_else( + self, + interp_: _QASM2ValidationAnalysis, + frame: ForwardFrame[EmptyLattice], + stmt: scf.IfElse, + ): + + # TODO: stmt.condition has to be based off a measurement + + if len(stmt.then_body.blocks) > 1: + interp_.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Only single block is allowed in the then-body of an if-else statement!", + ), + ) + + then_stmts = list(stmt.then_body.stmts()) + if len(then_stmts) > 2: + interp_.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Only single statements are allowed inside the then-body of an if-else statement!", + ), + ) + + if not isinstance(then_stmts[0], DontLiftType): + interp_.add_validation_error( + stmt, + ir.ValidationError( + stmt, f"Statement {then_stmts[0]} not allowed inside if clause!" + ), + ) + + self.__validate_empty_yield(interp_, then_stmts[-1]) + + if len(stmt.else_body.blocks) > 1: + interp_.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Only single block is allowed in the else-body of an if-else statement!", + ), + ) + + else_stmts = list(stmt.else_body.stmts()) + if len(else_stmts) > 1: + interp_.add_validation_error( + stmt, + ir.ValidationError(stmt, "Non-empty else is not allowed in QASM2!"), + ) + + self.__validate_empty_yield(interp_, else_stmts[-1]) + + def __validate_empty_yield( + self, interp_: _QASM2ValidationAnalysis, stmt: ir.Statement + ): + if not isinstance(stmt, scf.Yield): + interp_.add_validation_error( + stmt, + ir.ValidationError( + stmt, f"Expected scf.Yield terminator in if clause, got {stmt}" + ), + ) + elif len(stmt.values) > 0: + interp_.add_validation_error( + stmt, ir.ValidationError(stmt, "Cannot yield values from if statement!") + ) + + @interp.impl(scf.For) + def for_loop( + self, + interp_: _QASM2ValidationAnalysis, + frame: ForwardFrame[EmptyLattice], + stmt: scf.For, + ): + interp_.add_validation_error( + stmt, ir.ValidationError(stmt, "Loops not supported in QASM2!") + ) + + +class QASM2Validation(ValidationPass): + def name(self) -> str: + return "QASM2 validation" + + def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]: + analysis = _QASM2ValidationAnalysis(method.dialects) + frame, _ = analysis.run(method) + return frame, analysis.get_validation_errors() diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/emit/target.py index 784034a43..33834993a 100644 --- a/src/bloqade/qasm2/emit/target.py +++ b/src/bloqade/qasm2/emit/target.py @@ -4,6 +4,7 @@ from rich.console import Console from kirin.analysis import CallGraph from kirin.dialects import ilist +from kirin.validation import ValidationSuite from bloqade.qasm2.parse import ast, pprint from bloqade.qasm2.passes.fold import QASM2Fold @@ -14,6 +15,7 @@ from . import impls as impls # register the tables from .gate import EmitQASM2Gate from .main import EmitQASM2Main +from ..analysis import QASM2Validation class QASM2: @@ -114,6 +116,8 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: # rewrite parallel to uop ParallelToUOp(dialects=entry.dialects)(entry) + ValidationSuite([QASM2Validation]).validate(entry).raise_if_invalid() + Py2QASM(entry.dialects)(entry) target_main = EmitQASM2Main(self.main_target).initialize() target_main.run(entry) diff --git a/src/bloqade/qasm2/groups.py b/src/bloqade/qasm2/groups.py index 4e4955627..222cbaec0 100644 --- a/src/bloqade/qasm2/groups.py +++ b/src/bloqade/qasm2/groups.py @@ -1,7 +1,10 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt from kirin.dialects import scf, func, ilist, ssacfg, lowering +from kirin.validation import ValidationSuite +from bloqade.qasm2.passes import UnrollIfs +from bloqade.qasm2.analysis import QASM2Validation from bloqade.qasm2.dialects import ( uop, core, @@ -64,6 +67,7 @@ def run_pass( def main(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) + unroll_ifs = UnrollIfs(self) def run_pass( method: ir.Method, @@ -78,6 +82,10 @@ def run_pass( typeinfer_pass(method) method.verify_type() + unroll_ifs(method) + + validation_result = ValidationSuite([QASM2Validation]).validate(method) + validation_result.raise_if_invalid() return run_pass diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index f7e776eb7..338e3368f 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -1,6 +1,3 @@ -import pytest -from kirin.interp import InterpreterError - from bloqade import qasm2 @@ -240,8 +237,13 @@ def non_empty_else(): target = qasm2.emit.QASM2() - with pytest.raises(InterpreterError): + had_error = False + try: target.emit(non_empty_else) + except Exception: + # TODO: this is just to work around ExceptionGroup for now + had_error = True + assert had_error @qasm2.extended def multiline_then(): @@ -256,8 +258,13 @@ def multiline_then(): return q target = qasm2.emit.QASM2(unroll_ifs=False) - with pytest.raises(InterpreterError): + had_error = False + try: target.emit(multiline_then) + except Exception: + # TODO: this is just to work around ExceptionGroup for now + had_error = True + assert had_error @qasm2.extended def valid_if(): diff --git a/test/qasm2/test_count.py b/test/qasm2/test_count.py index 32aed54bd..b71d27273 100644 --- a/test/qasm2/test_count.py +++ b/test/qasm2/test_count.py @@ -3,7 +3,6 @@ from bloqade import qasm2 from bloqade.analysis.address import ( - Unknown, AddressReg, ConstResult, AddressQubit, @@ -51,27 +50,28 @@ def tuple_count(): assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7) -def test_dynamic_address(): - @qasm2.main - def dynamic_address(): - ra = qasm2.qreg(3) - rb = qasm2.qreg(4) - ca = qasm2.creg(2) - qasm2.measure(ra[0], ca[0]) - qasm2.measure(rb[1], ca[1]) - if ca[0] == ca[1]: - ret = ra - else: - ret = rb - - return ret - - # dynamic_address.code.print() - dynamic_address.print() - fold(dynamic_address) - frame, result = address.run(dynamic_address) - dynamic_address.print(analysis=frame.entries) - assert isinstance(result, Unknown) +# NOTE: this is also invalid for QASM2 - you can't yield from if statements and no else bodies +# def test_dynamic_address(): +# @qasm2.main +# def dynamic_address(): +# ra = qasm2.qreg(3) +# rb = qasm2.qreg(4) +# ca = qasm2.creg(2) +# qasm2.measure(ra[0], ca[0]) +# qasm2.measure(rb[1], ca[1]) +# if ca[0] == ca[1]: +# ret = ra +# else: +# ret = rb + +# return ret + +# # dynamic_address.code.print() +# dynamic_address.print() +# fold(dynamic_address) +# frame, result = address.run(dynamic_address) +# dynamic_address.print(analysis=frame.entries) +# assert isinstance(result, Unknown) # NOTE: this is invalid for QASM2