Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 56 additions & 75 deletions mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"

#include "Quantum/IR/QuantumInterfaces.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"

Expand All @@ -40,6 +41,7 @@ namespace quantum {
/// - A runtime Value (for dynamic indices computed at runtime)
/// - An IntegerAttr (for compile-time constant indices)
/// - Invalid/uninitialized (represented by std::monostate)
/// And a qreg value to represent the qreg that the index belongs to
///
/// The struct uses std::variant to ensure only one type is active at a time,
/// preventing invalid states.
Expand All @@ -54,17 +56,21 @@ namespace quantum {
/// Value idx = dynamicIdx.getValue(); // Get the Value
/// }
/// }
struct QubitIndex {
class QubitIndex {
private:
// use monostate to represent the invalid index
std::variant<std::monostate, Value, IntegerAttr> index;
Value qreg;

QubitIndex() : index(std::monostate()) {}
QubitIndex(Value val) : index(val) {}
QubitIndex(IntegerAttr attr) : index(attr) {}
public:
QubitIndex() : index(std::monostate()), qreg(nullptr) {}
QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {}
QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {}

bool isValue() const { return std::holds_alternative<Value>(index); }
bool isAttr() const { return std::holds_alternative<IntegerAttr>(index); }
operator bool() const { return isValue() || isAttr(); }
Value getReg() const { return qreg; }
Value getValue() const { return isValue() ? std::get<Value>(index) : nullptr; }
IntegerAttr getAttr() const { return isAttr() ? std::get<IntegerAttr>(index) : nullptr; }
};
Expand All @@ -76,25 +82,16 @@ class OpSignatureAnalyzer {
public:
OpSignatureAnalyzer() = delete;
OpSignatureAnalyzer(CustomOp op, bool enableQregMode)
: signature(OpSignature{
.params = op.getParams(),
.inQubits = op.getInQubits(),
.inCtrlQubits = op.getInCtrlQubits(),
.inCtrlValues = op.getInCtrlValues(),
.outQubits = op.getOutQubits(),
.outCtrlQubits = op.getOutCtrlQubits(),
})
: signature(OpSignature{.params = op.getParams(),
.inQubits = op.getNonCtrlQubitOperands(),
.inCtrlQubits = op.getCtrlQubitOperands(),
.inCtrlValues = op.getCtrlValueOperands(),
.outQubits = op.getNonCtrlQubitResults(),
.outCtrlQubits = op.getCtrlQubitResults()})
{
if (!enableQregMode)
return;

signature.sourceQreg = getSourceQreg(signature.inQubits.front());
if (!signature.sourceQreg) {
op.emitError("Cannot get source qreg");
isValid = false;
return;
}

// input wire indices
for (Value qubit : signature.inQubits) {
const QubitIndex index = getExtractIndex(qubit);
Expand All @@ -117,13 +114,23 @@ class OpSignatureAnalyzer {
signature.inCtrlWireIndices.emplace_back(index);
}

assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 &&
"inWireIndices or inCtrlWireIndices should not be empty");

// Output qubit indices are the same as input qubit indices
signature.outQubitIndices = signature.inWireIndices;
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
}

operator bool() const { return isValid; }

Value getUpdatedQreg(PatternRewriter &rewriter, Location loc)
{
// FIXME: This will cause an issue when the decomposition function has cross-qreg
// inputs and outputs. Now, we just assume has only one qreg input, the global one exists.
return signature.inWireIndices[0].getReg();
}

// Prepare the operands for calling the decomposition function
// There are two cases:
// 1. The first input is a qreg, which means the decomposition function is a qreg mode function
Expand All @@ -144,15 +151,8 @@ class OpSignatureAnalyzer {

int operandIdx = 0;
if (isa<quantum::QuregType>(funcInputs[0])) {
Value updatedQreg = signature.sourceQreg;
for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
const QubitIndex &index = signature.inWireIndices[i];
updatedQreg =
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
index.getValue(), index.getAttr(), qubit);
}
operands[operandIdx++] = getUpdatedQreg(rewriter, loc);

operands[operandIdx++] = updatedQreg;
if (!signature.params.empty()) {
auto [startIdx, endIdx] =
findParamTypeRange(funcInputs, signature.params.size(), operandIdx);
Expand All @@ -163,16 +163,12 @@ class OpSignatureAnalyzer {
}
}

if (!signature.inWireIndices.empty()) {
operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices,
funcInputs[operandIdx], rewriter, loc);
operandIdx++;
}

if (!signature.inCtrlWireIndices.empty()) {
operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices,
funcInputs[operandIdx], rewriter, loc);
operandIdx++;
for (const auto &indices : {signature.inWireIndices, signature.inCtrlWireIndices}) {
if (!indices.empty()) {
operands[operandIdx] =
fromTensorOrAsIs(indices, funcInputs[operandIdx], rewriter, loc);
operandIdx++;
}
}
}
else {
Expand Down Expand Up @@ -218,18 +214,16 @@ class OpSignatureAnalyzer {

SmallVector<Value> newResults;
rewriter.setInsertionPointAfter(callOp);
for (const QubitIndex &index : signature.outQubitIndices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
for (const QubitIndex &index : signature.outCtrlQubitIndices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());

for (const auto &indices : {signature.outQubitIndices, signature.outCtrlQubitIndices}) {
for (const auto &index : indices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}
}

return newResults;
}

Expand All @@ -245,7 +239,6 @@ class OpSignatureAnalyzer {
ValueRange outCtrlQubits;

// Qreg mode specific information
Value sourceQreg = nullptr;
SmallVector<QubitIndex> inWireIndices;
SmallVector<QubitIndex> inCtrlWireIndices;
SmallVector<QubitIndex> outQubitIndices;
Expand Down Expand Up @@ -333,39 +326,21 @@ class OpSignatureAnalyzer {
return values.front();
}

Value getSourceQreg(Value qubit)
{
while (qubit) {
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
return extractOp.getQreg();
}

if (auto customOp = dyn_cast_or_null<quantum::CustomOp>(qubit.getDefiningOp())) {
if (customOp.getQubitOperands().empty()) {
break;
}
qubit = customOp.getQubitOperands()[0];
}
}

return nullptr;
}

QubitIndex getExtractIndex(Value qubit)
{
while (qubit) {
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
if (Value idx = extractOp.getIdx()) {
return QubitIndex(idx);
return QubitIndex(idx, extractOp.getQreg());
}
if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) {
return QubitIndex(idxAttr);
return QubitIndex(idxAttr, extractOp.getQreg());
}
}

if (auto customOp = dyn_cast_or_null<quantum::CustomOp>(qubit.getDefiningOp())) {
auto qubitOperands = customOp.getQubitOperands();
auto qubitResults = customOp.getQubitResults();
if (auto gate = dyn_cast_or_null<quantum::QuantumGate>(qubit.getDefiningOp())) {
auto qubitOperands = gate.getQubitOperands();
auto qubitResults = gate.getQubitResults();
auto it =
llvm::find_if(qubitResults, [&](Value result) { return result == qubit; });

Expand All @@ -377,6 +352,10 @@ class OpSignatureAnalyzer {
}
}
}
else if (auto measureOp = dyn_cast_or_null<quantum::MeasureOp>(qubit.getDefiningOp())) {
qubit = measureOp.getInQubit();
continue;
}

break;
}
Expand All @@ -394,7 +373,8 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
DecomposeLoweringRewritePattern(MLIRContext *context,
const llvm::StringMap<func::FuncOp> &registry,
const llvm::StringSet<llvm::MallocAllocator> &gateSet)
: OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet)
: OpRewritePattern<CustomOp>(context), decompositionRegistry(registry),
targetGateSet(gateSet)
{
}

Expand All @@ -421,11 +401,12 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
assert(decompFunc.getFunctionType().getNumResults() >= 1 &&
"Decomposition function must have at least one result");

rewriter.setInsertionPointAfter(op);

auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType().getInput(0));
auto analyzer = OpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

rewriter.setInsertionPointAfter(op);
auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
rewriter.create<func::CallOp>(op.getLoc(), decompFunc.getFunctionType().getResults(),
Expand Down Expand Up @@ -453,4 +434,4 @@ void populateDecomposeLoweringPatterns(RewritePatternSet &patterns,
}

} // namespace quantum
} // namespace catalyst
} // namespace catalyst
46 changes: 46 additions & 0 deletions mlir/test/Quantum/DecomposeLoweringTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,49 @@ module @cnot_alternative_decomposition {
return %out_qubits_2#0, %out_qubits_4 : !quantum.bit, !quantum.bit
}
}

// -----

module @mcm_example {
func.func public @test_mcm_hadamard() -> tensor<2xf64> {
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
%mres, %out_qubit = quantum.measure %1 : i1, !quantum.bit
%2 = quantum.insert %0[ 0], %out_qubit : !quantum.reg, !quantum.bit

// CHECK: [[RZ_QUBIT:%.+]] = quantum.custom "RZ"([[CST_0:%.+]])
// CHECK: [[RY_QUBIT:%.+]] = quantum.custom "RY"([[CST_1:%.+]]) [[RZ_QUBIT]] : !quantum.bit
// CHECK: [[REG_1:%.+]] = quantum.insert [[REG:%.+]][[[EXTRACTED:%.+]]], [[RY_QUBIT]] : !quantum.reg, !quantum.bit
// CHECK-NOT: quantum.custom "Hadamard"
%3 = quantum.extract %2[ 0] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "Hadamard"() %3 : !quantum.bit
%4 = quantum.insert %2[ 0], %out_qubits : !quantum.reg, !quantum.bit

%5 = quantum.compbasis qreg %4 : !quantum.obs
%6 = quantum.probs %5 : tensor<2xf64>
quantum.dealloc %4 : !quantum.reg
return %6 : tensor<2xf64>
}

// Decomposition function should be applied and removed from the module
// CHECK-NOT: func.func public @rz_ry
func.func public @rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>, num_wires = 1 : i64, target_gate = "Hadamard"} {
%cst = arith.constant 3.1415926535897931 : f64
%cst_0 = arith.constant 1.5707963267948966 : f64
%0 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor<i64>
%extracted = tensor.extract %1[] : tensor<i64>
%2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "RZ"(%cst_0) %2 : !quantum.bit
%3 = stablehlo.slice %arg1 [0:1] : (tensor<1xi64>) -> tensor<1xi64>
%4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor<i64>
%extracted_1 = tensor.extract %1[] : tensor<i64>
%5 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
%extracted_2 = tensor.extract %4[] : tensor<i64>
%6 = quantum.extract %5[%extracted_2] : !quantum.reg -> !quantum.bit
%out_qubits_3 = quantum.custom "RY"(%cst) %6 : !quantum.bit
%extracted_4 = tensor.extract %4[] : tensor<i64>
%7 = quantum.insert %5[%extracted_4], %out_qubits_3 : !quantum.reg, !quantum.bit
return %7 : !quantum.reg
}
}