Skip to content
Draft
41 changes: 40 additions & 1 deletion mlir/include/mlir/Dialect/Common/IR/CommonTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,41 @@

#pragma once

#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/Operation.h>
#include <mlir/Support/LLVM.h>
#include <stdexcept>

namespace mqt::ir::common {
template <size_t N> class TargetArityTrait {

template <std::size_t NumQubits> struct DefinitionMatrix {
static constexpr std::size_t MatrixSize = 1 << NumQubits;

template <typename T>
using MatrixType = std::array<T, MatrixSize * MatrixSize>;

MatrixType<double (*)(mlir::ValueRange)> matrix;

static constexpr std::size_t index(std::size_t x, std::size_t y) {
return (y * MatrixSize) + x;
}

constexpr MatrixType<double> getMatrix(mlir::ValueRange params) {
// TODO? lazy-initialized cache
MatrixType<double> result;
static_assert(result.size() == matrix.size());
for (std::size_t i = 0; i < result.size(); ++i) {
result[i] = matrix[i](params);
}
return result;
}
};

template <size_t N, DefinitionMatrix<N> Matrix> class TargetArityTrait {
public:
template <typename ConcreteOp>
class Impl : public mlir::OpTrait::TraitBase<ConcreteOp, Impl> {
Expand All @@ -29,6 +57,17 @@ template <size_t N> class TargetArityTrait {
}
return mlir::success();
}

[[nodiscard]] static auto getDefinitionMatrix() { return Matrix; }
[[nodiscard]] static auto getDefinitionMatrix(mlir::Operation* op) {
auto concreteOp = mlir::cast<ConcreteOp>(op);
return Matrix.getMatrix(concreteOp.getParams());
}
[[nodiscard]] static double getDefinitionMatrixElement(mlir::Operation* op,
std::size_t x,
std::size_t y) {
return getDefinitionMatrix(op).at(DefinitionMatrix<N>::index(x, y));
}
};
};

Expand Down
35 changes: 33 additions & 2 deletions mlir/include/mlir/Dialect/Common/IR/CommonTraits.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,39 @@ include "mlir/IR/EnumAttr.td"
include "mlir/IR/DialectBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class TargetArity<int N>
: ParamNativeOpTrait<"TargetArityTrait", !cast<string>(N)> {
class DecimalFraction<int BeforeSeparator, int AfterSeparator> {
int Before = BeforeSeparator;
int After = AfterSeparator;
}
class Int<int I> : DecimalFraction<I, 0>;

class DefinitionElementType<int Id> {
int id = Id;
}
def Value : DefinitionElementType<0>;
def ParameterIndex : DefinitionElementType<1>;

class FunctionType<int Id> {
int id = Id;
}

def Identity : FunctionType<0>;
def Sin : FunctionType<1>;
def Cos : FunctionType<2>;

class DefinitionElement<DecimalFraction V, DefinitionElementType T, FunctionType F = Identity> {
DecimalFraction v = V;
DefinitionElementType t = T;
FunctionType f = F;
}

class Value<DecimalFraction V> : DefinitionElement<V, Value, Identity>;

class TargetArity<int N, list<DefinitionElement> MatrixDefinition = []>
: ParamNativeOpTrait<
"TargetArityTrait",
!cast<string>(N) # ", {{" # !foldl("", MatrixDefinition, acc, var, "[]([[maybe_unused]] mlir::ValueRange params) { return " # !cond(!eq(var.f, Identity) : "std::identity{}(") # !cond(!eq(var.t, Value) : var.v.Before # "." # var.v.After) # "); }") # "}}"
> {
let cppNamespace = "::mqt::ir::common";
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Common/IR/StdOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def GPhaseOp : UnitaryOp<"gphase", [NoTarget, OneParameter]> {
}];
}

def IOp : UnitaryOp<"i", [OneTarget, NoParameter]> {
def IOp : UnitaryOp<"i", [TargetArity<1, [Value<Int<1>>, Value<Int<0>>, Value<Int<0>>, Value<Int<1>>]>, NoParameter]> {
let summary = "I operation";

let description = [{
Expand Down
29 changes: 29 additions & 0 deletions mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,35 @@ def UnitaryInterface : OpInterface<"UnitaryInterface"> {
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return $_op->getName().getStringRef().split('.').second;
}]>,
InterfaceMethod<
/*desc=*/ "Check if operation operates on a single qubit.",
/*returnType=*/ "bool",
/*methodName=*/ "isSingleQubitOperation",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto&& inQubits = $_op.getInQubits();
auto&& outQubits = $_op.getOutQubits();
return inQubits.size() == 1 && outQubits.size() == 1 && !$_op.isControlled();
}]>,
InterfaceMethod<
/*desc=*/ "Check if operation operates on exactly two qubits.",
/*returnType=*/ "bool",
/*methodName=*/ "isTwoQubitOperation",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto&& inQubits = $_op.getInQubits();
auto&& inPosCtrlQubits = $_op.getPosCtrlInQubits();
auto&& inNegCtrlQubits = $_op.getNegCtrlInQubits();
auto&& outQubits = $_op.getOutQubits();
auto&& outPosCtrlQubits = $_op.getPosCtrlInQubits();
auto&& outNegCtrlQubits = $_op.getNegCtrlInQubits();

auto inQubitSize = inQubits.size() + inPosCtrlQubits.size() + inNegCtrlQubits.size();
auto outQubitSize = outQubits.size() + outPosCtrlQubits.size() + outNegCtrlQubits.size();
return inQubitSize == 2 && outQubitSize == 2;
}]>
];

Expand Down
Loading