diff --git a/mlir/include/mlir/Dialect/Common/IR/CommonTraits.h b/mlir/include/mlir/Dialect/Common/IR/CommonTraits.h index 23ec89ceb..9009fb5db 100644 --- a/mlir/include/mlir/Dialect/Common/IR/CommonTraits.h +++ b/mlir/include/mlir/Dialect/Common/IR/CommonTraits.h @@ -10,13 +10,41 @@ #pragma once +#include +#include #include +#include #include #include #include +#include namespace mqt::ir::common { -template class TargetArityTrait { + +template struct DefinitionMatrix { + static constexpr std::size_t MatrixSize = 1 << NumQubits; + + template + using MatrixType = std::array; + + MatrixType matrix; + + static constexpr std::size_t index(std::size_t x, std::size_t y) { + return (y * MatrixSize) + x; + } + + constexpr MatrixType getMatrix(mlir::ValueRange params) { + // TODO? lazy-initialized cache + MatrixType result; + static_assert(result.size() == matrix.size()); + for (std::size_t i = 0; i < result.size(); ++i) { + result[i] = matrix[i](params); + } + return result; + } +}; + +template Matrix> class TargetArityTrait { public: template class Impl : public mlir::OpTrait::TraitBase { @@ -29,6 +57,17 @@ template class TargetArityTrait { } return mlir::success(); } + + [[nodiscard]] static auto getDefinitionMatrix() { return Matrix; } + [[nodiscard]] static auto getDefinitionMatrix(mlir::Operation* op) { + auto concreteOp = mlir::cast(op); + return Matrix.getMatrix(concreteOp.getParams()); + } + [[nodiscard]] static double getDefinitionMatrixElement(mlir::Operation* op, + std::size_t x, + std::size_t y) { + return getDefinitionMatrix(op).at(DefinitionMatrix::index(x, y)); + } }; }; diff --git a/mlir/include/mlir/Dialect/Common/IR/CommonTraits.td b/mlir/include/mlir/Dialect/Common/IR/CommonTraits.td index 55849f726..550ec9a60 100644 --- a/mlir/include/mlir/Dialect/Common/IR/CommonTraits.td +++ b/mlir/include/mlir/Dialect/Common/IR/CommonTraits.td @@ -14,8 +14,39 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/DialectBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -class TargetArity - : ParamNativeOpTrait<"TargetArityTrait", !cast(N)> { +class DecimalFraction { + int Before = BeforeSeparator; + int After = AfterSeparator; +} +class Int : DecimalFraction; + +class DefinitionElementType { + int id = Id; +} +def Value : DefinitionElementType<0>; +def ParameterIndex : DefinitionElementType<1>; + +class FunctionType { + int id = Id; +} + +def Identity : FunctionType<0>; +def Sin : FunctionType<1>; +def Cos : FunctionType<2>; + +class DefinitionElement { + DecimalFraction v = V; + DefinitionElementType t = T; + FunctionType f = F; +} + +class Value : DefinitionElement; + +class TargetArity MatrixDefinition = []> + : ParamNativeOpTrait< + "TargetArityTrait", + !cast(N) # ", {{" # !foldl("", MatrixDefinition, acc, var, "[]([[maybe_unused]] mlir::ValueRange params) { return " # !cond(!eq(var.f, Identity) : "std::identity{}(") # !cond(!eq(var.t, Value) : var.v.Before # "." # var.v.After) # "); }") # "}}" + > { let cppNamespace = "::mqt::ir::common"; } diff --git a/mlir/include/mlir/Dialect/Common/IR/StdOps.td.inc b/mlir/include/mlir/Dialect/Common/IR/StdOps.td.inc index 675c4b992..ebf57c77a 100644 --- a/mlir/include/mlir/Dialect/Common/IR/StdOps.td.inc +++ b/mlir/include/mlir/Dialect/Common/IR/StdOps.td.inc @@ -22,7 +22,7 @@ def GPhaseOp : UnitaryOp<"gphase", [NoTarget, OneParameter]> { }]; } -def IOp : UnitaryOp<"i", [OneTarget, NoParameter]> { +def IOp : UnitaryOp<"i", [TargetArity<1, [Value>, Value>, Value>, Value>]>, NoParameter]> { let summary = "I operation"; let description = [{ diff --git a/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td b/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td index 7562057fe..f856d0a1b 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td +++ b/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td @@ -167,6 +167,35 @@ def UnitaryInterface : OpInterface<"UnitaryInterface"> { /*methodBody=*/ [{}], /*defaultImpl=*/ [{ return $_op->getName().getStringRef().split('.').second; + }]>, + InterfaceMethod< + /*desc=*/ "Check if operation operates on a single qubit.", + /*returnType=*/ "bool", + /*methodName=*/ "isSingleQubitOperation", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto&& inQubits = $_op.getInQubits(); + auto&& outQubits = $_op.getOutQubits(); + return inQubits.size() == 1 && outQubits.size() == 1 && !$_op.isControlled(); + }]>, + InterfaceMethod< + /*desc=*/ "Check if operation operates on exactly two qubits.", + /*returnType=*/ "bool", + /*methodName=*/ "isTwoQubitOperation", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto&& inQubits = $_op.getInQubits(); + auto&& inPosCtrlQubits = $_op.getPosCtrlInQubits(); + auto&& inNegCtrlQubits = $_op.getNegCtrlInQubits(); + auto&& outQubits = $_op.getOutQubits(); + auto&& outPosCtrlQubits = $_op.getPosCtrlInQubits(); + auto&& outNegCtrlQubits = $_op.getNegCtrlInQubits(); + + auto inQubitSize = inQubits.size() + inPosCtrlQubits.size() + inNegCtrlQubits.size(); + auto outQubitSize = outQubits.size() + outPosCtrlQubits.size() + outNegCtrlQubits.size(); + return inQubitSize == 2 && outQubitSize == 2; }]> ];