diff --git a/CHANGELOG.md b/CHANGELOG.md index 3df6a465d6..4d0330315a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ with the exception that minor releases may include breaking changes. [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1710], [#1717], [#1728], [#1730], [#1749], [#1751], [#1762], [#1765], - [#1774], [#1781], [#1782]) + [#1774], [#1781], [#1782], [#1787]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) @@ -598,6 +598,7 @@ changelogs._ +[#1787]: https://github.com/munich-quantum-toolkit/core/pull/1787 [#1782]: https://github.com/munich-quantum-toolkit/core/pull/1782 [#1781]: https://github.com/munich-quantum-toolkit/core/pull/1781 [#1776]: https://github.com/munich-quantum-toolkit/core/pull/1776 diff --git a/mlir/include/mlir/Conversion/QCToQIR/QIRCommon/QIRCommon.h b/mlir/include/mlir/Conversion/QCToQIR/QIRCommon/QIRCommon.h index c92bf5f9d2..2434359f3a 100644 --- a/mlir/include/mlir/Conversion/QCToQIR/QIRCommon/QIRCommon.h +++ b/mlir/include/mlir/Conversion/QCToQIR/QIRCommon/QIRCommon.h @@ -10,8 +10,6 @@ #pragma once -#include "mlir/Dialect/QIR/Utils/QIRMetadata.h" - #include #include #include @@ -26,7 +24,6 @@ #include namespace mlir { -using namespace qir; /** @brief Qubit allocation mode */ enum class AllocationMode : std::uint8_t { @@ -38,7 +35,7 @@ enum class AllocationMode : std::uint8_t { /** * @brief State object for tracking lowering information during QIR conversion */ -struct LoweringState : QIRMetadata { +struct LoweringState { /// Cache static qubit pointers for reuse DenseMap staticQubits; diff --git a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h index 30802171eb..3ad99aa526 100644 --- a/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QIR/Builder/QIRProgramBuilder.h @@ -10,8 +10,6 @@ #pragma once -#include "mlir/Dialect/QIR/Utils/QIRMetadata.h" - #include #include #include @@ -1086,15 +1084,18 @@ class QIRProgramBuilder final : public ImplicitLocOpBuilder { /// Map from register to their loaded indices DenseMap> loadedQubits; - /// Track qubit and result counts for QIR metadata - QIRMetadata metadata_; - /// Helper variable for storing the LLVM pointer type Type ptrType; /// Helper variable for storing the LLVM void type Type voidType; + /// The number of used qubits. + size_t numQubits{0}; + + /// The number of result values. + size_t numResults{0}; + /** * @brief Helper to create a LLVM CallOp * diff --git a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td index efb7fea4b1..d88f5d79fa 100644 --- a/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QIR/Transforms/Passes.td @@ -11,6 +11,15 @@ include "mlir/Pass/PassBase.td" +def QIRSetAttributesAndMetadata + : Pass<"set-qir-attributes-and-metadata", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + let summary = "Sets the required attributes to the entry point function and " + "adds the required module flags, compliant with QIR 2.1"; + let options = [Option<"useAdaptive", "use-adaptive", "bool", + /*default= */ "true", "Specifies the profile.">]; +} + def QIRCleanupPass : Pass<"qir-cleanup", "mlir::ModuleOp"> { let dependentDialects = ["mlir::LLVM::LLVMDialect"]; let summary = "Remove redundant QIR runtime bookkeeping."; diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRMetadata.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRMetadata.h deleted file mode 100644 index d089b1b78b..0000000000 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRMetadata.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM - * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH - * All rights reserved. - * - * SPDX-License-Identifier: MIT - * - * Licensed under the MIT License - */ - -#pragma once - -#include - -namespace mlir::qir { - -/** - * @brief State object for tracking QIR metadata during conversion - * - * @details - * This struct maintains metadata about the QIR program being built: - * - Qubit and result counts for QIR metadata - * - Whether dynamic memory management is needed - */ -struct QIRMetadata { - /// Number of qubits used in the module - size_t numQubits{0}; - /// Number of measurement results stored in the module - size_t numResults{0}; - /// Whether the module uses dynamic qubit management - bool useDynamicQubit{false}; - /// Whether the module uses dynamic result management - bool useDynamicResult{false}; - /// Whether the module uses arrays - bool useArrays{false}; - /// Whether the module uses backward branching (0 = none, 1 = iteration based, - /// 2 = condition based, 3 = both) - int backwardsBranching{0}; - /// Whether the module uses Adaptive Profile - bool useAdaptive{false}; -}; - -} // namespace mlir::qir diff --git a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h index 81a880192e..de6d055ec6 100644 --- a/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h +++ b/mlir/include/mlir/Dialect/QIR/Utils/QIRUtils.h @@ -10,8 +10,6 @@ #pragma once -#include "mlir/Dialect/QIR/Utils/QIRMetadata.h" - #include #include #include @@ -169,29 +167,6 @@ DEFINE_GETTER(XXMINUSYY) */ LLVM::LLVMFuncOp getMainFunction(Operation* op); -/** - * @brief Set QIR base profile metadata attributes on the main function - * - * @details - * Adds the required metadata attributes for QIR base profile compliance: - * - `entry_point`: Marks the main entry point function - * - `output_labeling_schema`: labeled - * - `qir_profiles`: base_profile - * - `required_num_qubits`: Number of qubits used - * - `required_num_results`: Number of measurement results - * - `qir_major_version`: 2 - * - `qir_minor_version`: 1 - * - `dynamic_qubit_management`: true/false - * - `dynamic_result_management`: true/false - * - * These attributes are required by the QIR specification and inform QIR - * consumers about the module's resource requirements and capabilities. - * - * @param main The main LLVM function to annotate - * @param metadata The QIR metadata containing qubit/result counts - */ -void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata); - /** * @brief Get or create a QIR function declaration * diff --git a/mlir/include/mlir/Support/Passes.h b/mlir/include/mlir/Support/Passes.h index 9a7ec7a2e8..0f10e92d44 100644 --- a/mlir/include/mlir/Support/Passes.h +++ b/mlir/include/mlir/Support/Passes.h @@ -17,6 +17,14 @@ class ModuleOp; class PassManager; } // namespace mlir +/** + * @brief Populate the pass manager and run it on the module. + */ +mlir::LogicalResult +runWithPassManager(mlir::ModuleOp module, + mlir::function_ref populatePasses, + mlir::StringRef errorMessage); + /** * @brief Populate a QC-oriented cleanup pipeline on the given pass manager. * @details Adds generic cleanup and QC qubit-register shrinking. @@ -31,9 +39,10 @@ void populateQCOCleanupPipeline(mlir::PassManager& pm); /** * @brief Populate a QIR-oriented cleanup pipeline on the given pass manager. - * @details Adds generic cleanup and QIR-specific simplifications. + * @details Adds generic cleanup and QIR-specific simplifications. Updates the + * meta data accordingly. */ -void populateQIRCleanupPipeline(mlir::PassManager& pm); +void populateQIRCleanupPipeline(mlir::PassManager& pm, bool useAdaptive); /** * @brief Run the QC-oriented cleanup pipeline on a module. @@ -48,4 +57,5 @@ void populateQIRCleanupPipeline(mlir::PassManager& pm); /** * @brief Run the QIR-oriented cleanup pipeline on a module. */ -[[nodiscard]] mlir::LogicalResult runQIRCleanupPipeline(mlir::ModuleOp module); +[[nodiscard]] mlir::LogicalResult runQIRCleanupPipeline(mlir::ModuleOp module, + bool useAdaptive); diff --git a/mlir/lib/Compiler/CMakeLists.txt b/mlir/lib/Compiler/CMakeLists.txt index b36ef2ff3b..735c12d6fd 100644 --- a/mlir/lib/Compiler/CMakeLists.txt +++ b/mlir/lib/Compiler/CMakeLists.txt @@ -19,8 +19,8 @@ add_mlir_library( MLIRTransformUtils MLIRQCToQCO MLIRQCOToQC - MLIRQCToQIRAdaptive MLIRQCToQIRBase + MLIRQCToQIRAdaptive MLIRQCOTransforms MQT::MLIRSupport) diff --git a/mlir/lib/Compiler/CompilerPipeline.cpp b/mlir/lib/Compiler/CompilerPipeline.cpp index b3f6b9daef..821ccda2ee 100644 --- a/mlir/lib/Compiler/CompilerPipeline.cpp +++ b/mlir/lib/Compiler/CompilerPipeline.cpp @@ -199,18 +199,15 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, } // Stage 9: QC-to-QIR conversion (optional) if (convertToQIR) { - auto addConversionPass = [&](PassManager& pm) { - if (config_.convertToQIRBase) { - pm.addPass(createQCToQIRBase()); - } else { - pm.addPass(createQCToQIRAdaptive()); - } - }; - - if (failed(runStage(addConversionPass))) { + if (failed(runStage([&](PassManager& pm) { + if (config_.convertToQIRAdaptive) { + pm.addPass(createQCToQIRAdaptive()); + } else { + pm.addPass(createQCToQIRBase()); + } + }))) { return failure(); } - if (record != nullptr && config_.recordIntermediates) { record->afterQIRConversion = captureIR(module); if (config_.printIRAfterAllStages) { @@ -219,8 +216,9 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, } } // Stage 10: QIR cleanup (optional) - if (failed(runStage( - [&](PassManager& pm) { populateQIRCleanupPipeline(pm); }))) { + if (failed(runStage([&](PassManager& pm) { + populateQIRCleanupPipeline(pm, config_.convertToQIRAdaptive); + }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { diff --git a/mlir/lib/Conversion/QCToQIR/QIRAdaptive/QCToQIRAdaptive.cpp b/mlir/lib/Conversion/QCToQIR/QIRAdaptive/QCToQIRAdaptive.cpp index bdd7ead97c..9b03e4d722 100644 --- a/mlir/lib/Conversion/QCToQIR/QIRAdaptive/QCToQIRAdaptive.cpp +++ b/mlir/lib/Conversion/QCToQIR/QIRAdaptive/QCToQIRAdaptive.cpp @@ -35,7 +35,6 @@ #include #include #include -#include #include #include @@ -84,9 +83,6 @@ struct ConvertMemRefAllocOp final } auto& state = getState(); - state.useDynamicQubit = true; - state.useArrays = true; - auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -235,7 +231,6 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { op.getOperation()))) { return failure(); } - state.useDynamicQubit = true; auto* ctx = getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -363,8 +358,6 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { matchAndRewrite(MeasureOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - state.useDynamicResult = true; - auto& resultArrays = state.resultArrays; auto& loadedResults = state.loadedResults; auto& resultPtrs = state.resultPtrs; @@ -381,7 +374,6 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { // Get result pointer Value result; if (op.getRegisterName() && op.getRegisterSize() && op.getRegisterIndex()) { - state.useArrays = true; const auto registerName = op.getRegisterName().value(); const auto registerSize = static_cast(op.getRegisterSize().value()); @@ -425,7 +417,6 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { rewriter.setInsertionPoint(state.entryBlock->getTerminator()); result = createPointerFromIndex(rewriter, op.getLoc(), resultPtrs.size()); resultPtrs.try_emplace(resultPtrs.size(), result); - state.numResults++; } rewriter.restoreInsertionPoint(savedInsertionPoint); @@ -573,22 +564,6 @@ struct QCToQIRAdaptive final : impl::QCToQIRAdaptiveBase { } } - /** - * @brief Iterates through the module to find any scf.while or scf.for - * operation to set the backward branching flag before they are converted to - * cf operations. - */ - static void setSCFFlags(Operation* op, LoweringState* state) { - op->walk([&](scf::ForOp) { - state->backwardsBranching += 1; - return WalkResult::interrupt(); - }); - op->walk([&](scf::WhileOp) { - state->backwardsBranching += 2; - return WalkResult::interrupt(); - }); - } - protected: /** * @brief Executes the QC to QIR conversion pass @@ -614,15 +589,11 @@ struct QCToQIRAdaptive final : impl::QCToQIRAdaptiveBase { * Convert QC dialect operations and memref operations to QIR calls and add * output recording to the output block. * - * **Stage 6: QIR Attributes** - * Add QIR Profile metadata to the main function, including qubit/result - * counts and version information. - * - * **Stage 7: Standard dialects to LLVM** + * **Stage 6: Standard dialects to LLVM** * Convert arith and control flow dialects to LLVM (for index arithmetic and * function control flow). * - * **Stage 8: Reconcile casts** + * **Stage 7: Reconcile casts** * Clean up any unrealized cast operations introduced during type * conversion. */ @@ -632,15 +603,11 @@ struct QCToQIRAdaptive final : impl::QCToQIRAdaptiveBase { ConversionTarget target(*ctx); QCToQIRTypeConverter typeConverter(ctx); LoweringState state; - state.useAdaptive = true; target.addLegalDialect(); // Stage 1: Convert scf dialect to cf { - // Find the required flags before the scf operations are converted - setSCFFlags(moduleOp, &state); - RewritePatternSet scfPatterns(ctx); target.addIllegalDialect(); target.addLegalDialect(); @@ -696,10 +663,7 @@ struct QCToQIRAdaptive final : impl::QCToQIRAdaptiveBase { releaseResults(main, ctx, &state); } - // Stage 6: Set QIR metadata attributes - setQIRAttributes(main, state); - - // Stage 7: Convert standard dialects to LLVM + // Stage 6: Convert standard dialects to LLVM { RewritePatternSet stdPatterns(ctx); target.addIllegalDialect(); @@ -716,7 +680,7 @@ struct QCToQIRAdaptive final : impl::QCToQIRAdaptiveBase { } } - // Stage 8: Reconcile unrealized casts + // Stage 7: Reconcile unrealized casts PassManager passManager(ctx); passManager.addPass(createReconcileUnrealizedCastsPass()); if (passManager.run(moduleOp).failed()) { diff --git a/mlir/lib/Conversion/QCToQIR/QIRBase/QCToQIRBase.cpp b/mlir/lib/Conversion/QCToQIR/QIRBase/QCToQIRBase.cpp index 74638efaae..c7043422f1 100644 --- a/mlir/lib/Conversion/QCToQIR/QIRBase/QCToQIRBase.cpp +++ b/mlir/lib/Conversion/QCToQIR/QIRBase/QCToQIRBase.cpp @@ -102,10 +102,10 @@ struct ConvertMemRefLoadOp final : StatefulOpConversionPattern { // Switch to entry block rewriter.setInsertionPoint(state.entryBlock->getTerminator()); + auto nqubits = state.staticQubits.size(); auto qubit = createPointerFromIndex(rewriter, op.getLoc(), - static_cast(state.numQubits)); - state.staticQubits.try_emplace(static_cast(state.numQubits++), - qubit); + static_cast(nqubits)); + state.staticQubits.try_emplace(static_cast(nqubits), qubit); rewriter.replaceOp(op, qubit); return success(); @@ -157,10 +157,10 @@ struct ConvertQCAllocOp final : StatefulOpConversionPattern { rewriter.setInsertionPoint(state.entryBlock->getTerminator()); + const auto nqubits = state.staticQubits.size(); auto qubit = createPointerFromIndex(rewriter, op.getLoc(), - static_cast(state.numQubits)); - state.staticQubits.try_emplace(static_cast(state.numQubits++), - qubit); + static_cast(nqubits)); + state.staticQubits.try_emplace(static_cast(nqubits), qubit); rewriter.replaceOp(op, qubit); return success(); @@ -221,20 +221,17 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { Value result; int64_t resultIndex = 0; + const auto nresults = resultPtrs.size(); if (op.getRegisterIndex() && op.getRegisterName() && op.getRegisterSize()) { const auto registerName = op.getRegisterName().value(); - const auto registerSize = op.getRegisterSize().value(); const auto registerIndex = op.getRegisterIndex().value(); // Assign a base offset to this register if not yet seen - if (!state.registerOffsets.contains(registerName)) { - state.registerOffsets.try_emplace(registerName, state.numResults); - state.numResults += registerSize; - } - resultIndex = state.registerOffsets[registerName] + - static_cast(registerIndex); + const auto [it, _] = + state.registerOffsets.try_emplace(registerName, nresults); + resultIndex = it->second + static_cast(registerIndex); } else { - resultIndex = static_cast(state.numResults); + resultIndex = static_cast(nresults); } if (resultPtrs.contains(resultIndex)) { @@ -244,9 +241,6 @@ struct ConvertQCMeasureOp final : StatefulOpConversionPattern { rewriter.setInsertionPoint(state.entryBlock->getTerminator()); result = createPointerFromIndex(rewriter, op.getLoc(), resultIndex); resultPtrs.try_emplace(resultIndex, result); - if (std::cmp_greater_equal(resultIndex, state.numResults)) { - state.numResults = resultIndex + 1; - } } // Switch to measurements block @@ -380,7 +374,7 @@ struct QCToQIRBase final : impl::QCToQIRBaseBase { * @brief Executes the QC to QIR conversion pass * * @details - * Performs the conversion in seven stages: + * Performs the conversion in six stages: * * **Stage 1: Func to LLVM** * Convert func dialect operations (main function) to LLVM dialect @@ -397,15 +391,11 @@ struct QCToQIRBase final : impl::QCToQIRBaseBase { * Convert QC dialect operations to QIR calls and add output recording to the * output block. * - * **Stage 5: QIR attributes** - * Add QIR base profile metadata to the main function, including qubit/result - * counts and version information. - * - * **Stage 6: Standard dialects to LLVM** + * **Stage 5: Standard dialects to LLVM** * Convert arith and control flow dialects to LLVM (for index arithmetic and * function control flow). * - * **Stage 7: Reconcile casts** + * **Stage 6: Reconcile casts** * Clean up any unrealized cast operations introduced during type conversion. */ void runOnOperation() override { @@ -460,10 +450,7 @@ struct QCToQIRBase final : impl::QCToQIRBaseBase { addOutputRecording(main, ctx, state); } - // Stage 5: Set QIR metadata attributes - setQIRAttributes(main, state); - - // Stage 6: Convert standard dialects to LLVM + // Stage 5: Convert standard dialects to LLVM { RewritePatternSet stdPatterns(ctx); target.addIllegalDialect(); @@ -480,7 +467,7 @@ struct QCToQIRBase final : impl::QCToQIRBaseBase { } } - // Stage 7: Reconcile unrealized casts + // Stage 6: Reconcile unrealized casts PassManager passManager(ctx); passManager.addPass(createReconcileUnrealizedCastsPass()); if (passManager.run(moduleOp).failed()) { diff --git a/mlir/lib/Conversion/QCToQIR/QIRCommon/QIRCommon.cpp b/mlir/lib/Conversion/QCToQIR/QIRCommon/QIRCommon.cpp index f197002ddd..aea4ae0a02 100644 --- a/mlir/lib/Conversion/QCToQIR/QIRCommon/QIRCommon.cpp +++ b/mlir/lib/Conversion/QCToQIR/QIRCommon/QIRCommon.cpp @@ -297,11 +297,6 @@ struct ConvertQCStaticOp final : StatefulOpConversionPattern { } rewriter.replaceOp(op, qubit); - // Track maximum qubit index - if (std::cmp_greater_equal(index, state.numQubits)) { - state.numQubits = index + 1; - } - return success(); } }; diff --git a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp index 771e62b926..f775264568 100644 --- a/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp +++ b/mlir/lib/Dialect/QIR/Builder/QIRProgramBuilder.cpp @@ -10,7 +10,9 @@ #include "mlir/Dialect/QIR/Builder/QIRProgramBuilder.h" +#include "mlir/Dialect/QIR/Transforms/Passes.h" #include "mlir/Dialect/QIR/Utils/QIRUtils.h" +#include "mlir/Support/Passes.h" #include #include @@ -28,6 +30,7 @@ #include #include #include +#include #include #include @@ -130,7 +133,6 @@ Value QIRProgramBuilder::allocQubit() { Value qubit; if (profile == Profile::Adaptive) { ensureAllocationMode(AllocationMode::Dynamic); - metadata_.useDynamicQubit = true; auto fnSig = LLVM::LLVMFunctionType::get(ptrType, {ptrType}); auto fnDec = @@ -139,7 +141,7 @@ Value QIRProgramBuilder::allocQubit() { auto zero = LLVM::ZeroOp::create(*this, ptrType); qubit = LLVM::CallOp::create(*this, fnDec, zero.getResult()).getResult(); } else { - qubit = staticQubit(static_cast(metadata_.numQubits)); + qubit = staticQubit(static_cast(numQubits)); } qubits.insert(qubit); @@ -169,8 +171,8 @@ Value QIRProgramBuilder::staticQubit(const int64_t index) { } // Update qubit count - if (std::cmp_greater_equal(index, metadata_.numQubits)) { - metadata_.numQubits = static_cast(index) + 1; + if (std::cmp_greater_equal(index, numQubits)) { + numQubits = static_cast(index) + 1; } return qubit; @@ -198,8 +200,8 @@ Value QIRProgramBuilder::staticResult(const int64_t index) { } // Update result count - if (std::cmp_greater_equal(index, metadata_.numResults)) { - metadata_.numResults = static_cast(index) + 1; + if (std::cmp_greater_equal(index, numResults)) { + numResults = static_cast(index) + 1; } return result; @@ -228,8 +230,6 @@ QIRProgramBuilder::allocQubitRegister(const int64_t size) { if (profile == Profile::Adaptive) { // Create a dynamic qubit array and load the qubits in the Adaptive Profile ensureAllocationMode(AllocationMode::Dynamic); - metadata_.useArrays = true; - metadata_.useDynamicQubit = true; auto allocFnSignature = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(getContext()), @@ -257,7 +257,7 @@ QIRProgramBuilder::allocQubitRegister(const int64_t size) { } else { // Create static qubits in the Base Profile for (int64_t i = 0; i < size; ++i) { - auto qubit = staticQubit(static_cast(metadata_.numQubits)); + auto qubit = staticQubit(static_cast(numQubits)); qubits.push_back(qubit); } } @@ -317,8 +317,6 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, if (profile == Profile::Adaptive) { // Create a dynamic result array for the Adaptive Profile - metadata_.useDynamicResult = true; - metadata_.useArrays = true; auto fnSig = LLVM::LLVMFunctionType::get(voidType, {getI64Type(), ptrType, ptrType}); @@ -343,7 +341,7 @@ QIRProgramBuilder::allocClassicalBitRegister(const int64_t size, } else { // Use static results in the Base Profile for (int64_t i = 0; i < size; ++i) { - auto result = staticResult(static_cast(metadata_.numResults)); + auto result = staticResult(static_cast(numResults)); loadedResults.try_emplace({stringSaver.save(name), i}, result); } } @@ -392,9 +390,7 @@ Value QIRProgramBuilder::measure(Value qubit, const Bit& bit) { llvm::reportFatalUsageError("Bit does not belong to a result pointer"); } auto result = it->second; - if (profile == Profile::Adaptive) { - metadata_.useDynamicResult = true; - } else { + if (profile != Profile::Adaptive) { setInsertionPoint(measurementsBlock->getTerminator()); } @@ -726,11 +722,6 @@ QIRProgramBuilder::scfFor(const std::variant& lowerbound, "Adaptive Profile is selected."); } - int& backwardsBranchingFlag = metadata_.backwardsBranching; - if (backwardsBranchingFlag != 1 && backwardsBranchingFlag != 3) { - backwardsBranchingFlag += 1; - } - auto loc = getLoc(); auto lb = resolveIntVariant(lowerbound); auto ub = resolveIntVariant(upperbound); @@ -853,11 +844,6 @@ QIRProgramBuilder::scfWhile(const function_ref& beforeBody, "Adaptive Profile is selected."); } - int& backwardsBranchingFlag = metadata_.backwardsBranching; - if (backwardsBranchingFlag != 2 && backwardsBranchingFlag != 3) { - backwardsBranchingFlag += 2; - } - auto* currentBlock = getInsertionBlock(); // Build the blocks auto* beforeBlock = createBlock(currentBlock->getParent(), @@ -985,11 +971,12 @@ OwningOpRef QIRProgramBuilder::finalize() { // Save current insertion point const InsertionGuard guard(*this); + const bool isAdaptive = (profile == Profile::Adaptive); // Release resources in output block setInsertionPoint(outputBlock->getTerminator()); - if (profile == Profile::Adaptive) { + if (isAdaptive) { for (auto qubit : qubits) { auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); auto dec = @@ -1009,7 +996,7 @@ OwningOpRef QIRProgramBuilder::finalize() { // Generate output recording in output block generateOutputRecording(); - if (profile == Profile::Adaptive) { + if (isAdaptive) { for (auto& [_, ptr] : resultPtrs) { auto sig = LLVM::LLVMFunctionType::get(voidType, {ptrType}); auto dec = getOrCreateFunctionDeclaration(*this, module, @@ -1026,13 +1013,17 @@ OwningOpRef QIRProgramBuilder::finalize() { } } - auto mainFuncOp = cast(mainFunc); - metadata_.useAdaptive = profile == Profile::Adaptive; - setQIRAttributes(mainFuncOp, metadata_); - + // Attach attributes + auto m = cast(module); + std::ignore = runWithPassManager( + m, + [&](PassManager& pm) { + pm.addPass(qir::createQIRSetAttributesAndMetadata({isAdaptive})); + }, + "Failed to attach attributes"); isFinalized = true; - return cast(module); + return m; } OwningOpRef QIRProgramBuilder::build( diff --git a/mlir/lib/Dialect/QIR/Transforms/AttachQIRAttributes.cpp b/mlir/lib/Dialect/QIR/Transforms/AttachQIRAttributes.cpp new file mode 100644 index 0000000000..749762ef79 --- /dev/null +++ b/mlir/lib/Dialect/QIR/Transforms/AttachQIRAttributes.cpp @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QIR/Transforms/Passes.h" +#include "mlir/Dialect/QIR/Utils/QIRUtils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace mlir::qir { +#define GEN_PASS_DEF_QIRSETATTRIBUTESANDMETADATA +#include "mlir/Dialect/QIR/Transforms/Passes.h.inc" + +namespace { + +/// State object for tracking QIR metadata during conversion +struct Metadata { + /// Number of qubits used in the module + size_t numQubits{0}; + /// Number of measurement results stored in the module + size_t numResults{0}; + /// Whether the module uses dynamic qubit management + bool useDynamicQubit{false}; + /// Whether the module uses dynamic result management + bool useDynamicResult{false}; + /// Whether the module uses arrays + bool useArrays{false}; + /// Whether the module uses backward branching (0 = none, 1 = iteration based, + /// 2 = condition based, 3 = both) + int backwardsBranching{0}; +}; + +/** + * @brief Attaches the required attributes to the function marked as + * entry_point. + */ +struct QIRSetAttributesAndMetadata final + : impl::QIRSetAttributesAndMetadataBase { + using QIRSetAttributesAndMetadataBase::QIRSetAttributesAndMetadataBase; + +protected: + void runOnOperation() override { + IRRewriter rewriter(&getContext()); + auto main = getMainFunction(getOperation()); + if (!main) { + return; + } + setMetadata(main, useAdaptive ? getAdaptive(main) : getBase(main), + rewriter); + } + +private: + /// Clear and set QIR base profile metadata. + /// + /// Adds the required metadata attributes for QIR base profile compliance: + /// - `entry_point`: Marks the main entry point function + /// - `output_labeling_schema`: labeled + /// - `qir_profiles`: base_profile + /// - `required_num_qubits`: Number of qubits used + /// - `required_num_results`: Number of measurement results + /// - `qir_major_version`: 2 + /// - `qir_minor_version`: 1 + /// - `dynamic_qubit_management`: true/false + /// - `dynamic_result_management`: true/false + /// + /// These attributes are required by the QIR specification and inform QIR + /// consumers about the module's resource requirements and capabilities. + void setMetadata(LLVM::LLVMFuncOp& main, const Metadata& metadata, + IRRewriter& rewriter) { + auto m = getOperation(); + const auto createFlag = [&](LLVM::ModFlagBehavior behavior, StringRef name, + int32_t val) { + return LLVM::ModuleFlagAttr::get(m->getContext(), behavior, + rewriter.getStringAttr(name), + rewriter.getI32IntegerAttr(val)); + }; + + const SmallVector attributes{ + rewriter.getStringAttr("entry_point"), + rewriter.getStrArrayAttr({"output_labeling_schema", "labeled"}), + rewriter.getStrArrayAttr({"qir_profiles", useAdaptive + ? "adaptive_profile" + : "base_profile"}), + rewriter.getStrArrayAttr( + {"required_num_qubits", std::to_string(metadata.numQubits)}), + rewriter.getStrArrayAttr( + {"required_num_results", std::to_string(metadata.numResults)})}; + + main->setAttr("passthrough", rewriter.getArrayAttr(attributes)); + + rewriter.setInsertionPointToEnd(m.getBody()); + + SmallVector flags{ + createFlag(LLVM::ModFlagBehavior::Error, "qir_major_version", 2), + createFlag(LLVM::ModFlagBehavior::Max, "qir_minor_version", 1), + createFlag(LLVM::ModFlagBehavior::Error, "dynamic_qubit_management", + static_cast(metadata.useDynamicQubit)), + createFlag(LLVM::ModFlagBehavior::Error, "dynamic_result_management", + static_cast(metadata.useDynamicResult))}; + + if (useAdaptive) { + flags.emplace_back(createFlag(LLVM::ModFlagBehavior::Error, + "backwards_branching", + metadata.backwardsBranching)); + flags.emplace_back(createFlag(LLVM::ModFlagBehavior::Error, "arrays", + static_cast(metadata.useArrays))); + } + + removeExistingModuleFlags(m, rewriter); + LLVM::ModuleFlagsOp::create(rewriter, m.getLoc(), + rewriter.getArrayAttr(flags)); + } + + /// Remove existing module flag operations from module. + /// Note that this might also erase non-QIR module flag operations, but for + /// now, we assume that there are no others. + static void removeExistingModuleFlags(ModuleOp m, IRRewriter& rewriter) { + SmallVector flagOps; + m->walk([&](LLVM::ModuleFlagsOp op) { flagOps.emplace_back(op); }); + llvm::for_each(flagOps, [&](Operation* op) { rewriter.eraseOp(op); }); + } + + /// Count the number of uniquely indexed qubit pointers. + /// Assumes that qubits are constant integers that are converted to + /// an integer pointer and then used in (at least) one quantum instruction. + static size_t getNumQubits(LLVM::LLVMFuncOp& main) { + static constexpr StringRef QIS_PREFIX = "__quantum__qis"; + + DenseSet seen; + main->walk([&](LLVM::ConstantOp constOp) { + if (constOp.use_empty()) { + return; + } + + const auto intAttr = dyn_cast(constOp.getValue()); + if (!intAttr) { + return; + } + + if (!intAttr.getType().isInteger()) { // Not a ": index". + return; + } + + const auto userIt = + llvm::find_if(constOp->getUsers(), [](Operation* user) { + return isa(user); + }); + if (userIt == constOp->user_end()) { + return; + } + + const auto toPtrOp = cast(*userIt); + const auto callIt = + llvm::find_if(toPtrOp->getUses(), [](OpOperand& operand) { + auto callOp = dyn_cast(operand.getOwner()); + if (!callOp) { + return false; + } + + auto callee = callOp.getCallee(); + if (!callee.has_value()) { + return false; + } + + if (*callee == QIR_MEASURE) { + + // The following assumes that the first argument of a + // measurement call is the qubit. This may (or may not) hold in + // the future. + + return operand.getOperandNumber() == 0; + } + + return callee->starts_with(QIS_PREFIX); + }); + if (callIt == toPtrOp->use_end()) { + return; + } + + // The set ensures that we don't insert the same index multiple times. + seen.insert(intAttr.getValue()); + }); + + return seen.size(); + } + + /// Count the number of uniquely indexed result_record_output statements. + static size_t getNumResults(LLVM::LLVMFuncOp& main) { + DenseSet seen; + main->walk([&](LLVM::CallOp callOp) { + if (!callOp.getCallee()) { + return; + } + + if (*callOp.getCallee() != QIR_RECORD_OUTPUT) { + return; + } + + const auto operand = callOp->getOperand(0); + auto toPtrOp = dyn_cast(operand.getDefiningOp()); + if (!toPtrOp) { + return; + } + + const auto arg = toPtrOp.getArg(); + auto constOp = dyn_cast(arg.getDefiningOp()); + if (!constOp) { + return; + } + + const auto intAttr = dyn_cast(constOp.getValue()); + if (!intAttr) { + return; + } + + // The set ensures that we don't insert the same index multiple times. + seen.insert(intAttr.getValue()); + }); + + return seen.size(); + } + + /// Determine whether a loop (as a set of blocks) is an iterative loop (true) + /// or a conditionally terminated loop (false). + static bool classifyLoop(const SmallPtrSet& loop) { + for (Block* block : loop) { + Operation* terminator = block->getTerminator(); + assert(terminator != nullptr); + + if (auto condBrOp = dyn_cast(terminator)) { + auto condition = condBrOp.getCondition(); + + if (isa(condition)) { // Ensure that there is a def-op. + return true; + } + + auto callOp = dyn_cast(condition.getDefiningOp()); + + // If the condition is not produced by a measurement call, we + // consider it a basic loop. + if (!callOp || !callOp.getCallee()) { + return true; + } + + // If the condition has been produced by a measurement call + // (e.g. a until-zero-measurement loop), and breaks outside the loop, + // we found a "conditionally terminating loop". + if (*callOp.getCallee() == QIR_READ_RESULT && + (!loop.contains(condBrOp.getTrueDest()) || + !loop.contains(condBrOp.getFalseDest()))) { + return false; + } + + // Unseen edge case (so far): The condition of the terminator + // operation is produced by a function call, which isn't a + // measurement. + return true; + } + } + } + + /// Return pair of booleans, indicating whether the entry point uses + /// iterations = [0] or conditionally terminated loops = [1]. + static std::pair + usesBackwardsBranching(LLVM::LLVMFuncOp& main, const DominanceInfo& domInfo) { + bool useIteration{false}; + bool useCondTerm{false}; + + SmallVector worklist; + + for (Block& block : main.getBlocks()) { + for (Block* successor : block.getSuccessors()) { + if (domInfo.dominates(successor, &block)) { // Back edge. + Block* header = successor; + Block* tail = █ + + SmallPtrSet loop{header}; + if (header != tail) { + worklist.push_back(tail); + } + + while (!worklist.empty()) { + Block* curr = worklist.pop_back_val(); + for (Block* pred : curr->getPredecessors()) { + if (loop.insert(pred).second) { + worklist.push_back(pred); + } + } + } + + if (classifyLoop(loop)) { + useIteration |= true; + } else { + useCondTerm |= true; + } + + loop.clear(); + } + } + } + + return std::make_pair(useIteration, useCondTerm); + } + + /// Return triple of booleans, indicating whether the entry point uses + /// dynamic qubits = [0], dynamic results = [1], or dynamic arrays = [2]. + static std::tuple usesDynamic(LLVM::LLVMFuncOp& main) { + bool useDynamicQubit{false}; + bool useDynamicResult{false}; + bool useArrays{false}; + + main->walk([&](LLVM::CallOp callOp) { + if (!callOp.getCallee()) { + return; + } + + const auto name = *callOp.getCallee(); + if (name == QIR_QUBIT_ALLOC) { + useDynamicQubit = true; + } else if (name == QIR_RESULT_ALLOC) { + useDynamicResult = true; + } else if (name == QIR_QUBIT_ARRAY_ALLOC) { + useDynamicQubit = true; + useArrays = true; + } else if (name == QIR_RESULT_ARRAY_ALLOC) { + useDynamicResult = true; + useArrays = true; + } + }); + + return std::make_tuple(useDynamicQubit, useDynamicResult, useArrays); + } + + /// Return the metadata for a QIR base profile compliant program. + static Metadata getBase(LLVM::LLVMFuncOp& main) { + return {.numQubits = getNumQubits(main), + .numResults = getNumResults(main), + .useDynamicQubit = false, + .useDynamicResult = false, + .useArrays = false, + .backwardsBranching = 0}; + } + + /// Return the metadata for a QIR adaptive profile compliant program. + Metadata getAdaptive(LLVM::LLVMFuncOp& main) { + const auto& domInfo = getAnalysis(); + const auto [useIteration, useCondTerm] = + usesBackwardsBranching(main, domInfo); + const auto [useDynamicQubit, useDynamicResult, useArrays] = + usesDynamic(main); + + Metadata md; + md.useDynamicQubit = useDynamicQubit; + md.useDynamicResult = useDynamicResult; + md.useArrays = useArrays; + + if (!useDynamicQubit) { + md.numQubits = getNumQubits(main); + } + + if (!useDynamicResult) { + md.numResults = getNumResults(main); + } + + if (useIteration) { + md.backwardsBranching = useCondTerm ? 3 : 1; + } else if (useCondTerm) { + md.backwardsBranching = 2; + } + + return md; + } +}; +} // namespace +} // namespace mlir::qir diff --git a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp index 309a5e9652..7a0893fc2e 100644 --- a/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp +++ b/mlir/lib/Dialect/QIR/Utils/QIRUtils.cpp @@ -10,8 +10,6 @@ #include "mlir/Dialect/QIR/Utils/QIRUtils.h" -#include "mlir/Dialect/QIR/Utils/QIRMetadata.h" - #include #include #include @@ -26,7 +24,6 @@ #include #include -#include namespace mlir::qir { @@ -55,65 +52,6 @@ LLVM::LLVMFuncOp getMainFunction(Operation* op) { return nullptr; } -void setQIRAttributes(LLVM::LLVMFuncOp& main, const QIRMetadata& metadata) { - auto module = main->getParentOfType(); - if (metadata.useDynamicQubit && metadata.numQubits != 0) { - llvm::reportFatalUsageError( - "Cannot use dynamic qubit allocation if static qubits are allocated"); - } - - OpBuilder builder(main.getBody()); - SmallVector attributes; - - // Core QIR attributes - attributes.emplace_back(builder.getStringAttr("entry_point")); - attributes.emplace_back( - builder.getStrArrayAttr({"output_labeling_schema", "labeled"})); - attributes.emplace_back(builder.getStrArrayAttr( - {"qir_profiles", - metadata.useAdaptive ? "adaptive_profile" : "base_profile"})); - - // Resource requirements - attributes.emplace_back(builder.getStrArrayAttr( - {"required_num_qubits", std::to_string(metadata.numQubits)})); - attributes.emplace_back(builder.getStrArrayAttr( - {"required_num_results", std::to_string(metadata.numResults)})); - - main->setAttr("passthrough", builder.getArrayAttr(attributes)); - - builder.setInsertionPointToEnd(module.getBody()); - - auto createFlag = [&](LLVM::ModFlagBehavior behavior, StringRef name, - int32_t val) { - return LLVM::ModuleFlagAttr::get(module->getContext(), behavior, - builder.getStringAttr(name), - builder.getI32IntegerAttr(val)); - }; - - SmallVector flags; - - flags.push_back( - createFlag(LLVM::ModFlagBehavior::Error, "qir_major_version", 2)); - flags.push_back( - createFlag(LLVM::ModFlagBehavior::Max, "qir_minor_version", 1)); - flags.push_back(createFlag(LLVM::ModFlagBehavior::Error, - "dynamic_qubit_management", - static_cast(metadata.useDynamicQubit))); - flags.push_back(createFlag(LLVM::ModFlagBehavior::Error, - "dynamic_result_management", - static_cast(metadata.useDynamicResult))); - if (metadata.useAdaptive) { - flags.push_back(createFlag(LLVM::ModFlagBehavior::Error, - "backwards_branching", - metadata.backwardsBranching)); - flags.push_back(createFlag(LLVM::ModFlagBehavior::Error, "arrays", - static_cast(metadata.useArrays))); - } - - LLVM::ModuleFlagsOp::create(builder, module.getLoc(), - builder.getArrayAttr(flags)); -} - LLVM::LLVMFuncOp getOrCreateFunctionDeclaration(OpBuilder& builder, Operation* op, StringRef fnName, Type fnType) { diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index a1700fa319..79ec785ba7 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -308,9 +309,24 @@ static bool compareAttributes(Attribute lhs, Attribute rhs) { !strAttrB || strAttrA.getValue() != strAttrB.getValue()) { return false; } - } else if (auto symbolRefAttrA = - llvm::dyn_cast(lhs)) { - auto symbolRefAttrB = llvm::dyn_cast(rhs); + } else if (auto arrayAttrA = llvm::dyn_cast(lhs)) { + auto arrayAttrB = llvm::dyn_cast(rhs); + if (!arrayAttrB) { + return false; + } + if (arrayAttrA.size() != arrayAttrB.size()) { + return false; + } + + for (const auto& [subAttrA, subAttrB] : + llvm::zip_equal(arrayAttrA, arrayAttrB)) { + if (!compareAttributes(subAttrA, subAttrB)) { + return false; + } + } + + } else if (auto symbolRefAttrA = dyn_cast(lhs)) { + auto symbolRefAttrB = dyn_cast(rhs); if (!symbolRefAttrB) { return false; } @@ -318,6 +334,44 @@ static bool compareAttributes(Attribute lhs, Attribute rhs) { if (symbolRefAttrA.getValue() != symbolRefAttrB.getValue()) { return false; } + } else if (auto tailCallAttrA = dyn_cast(lhs)) { + auto tailCallAttrB = dyn_cast(rhs); + if (!tailCallAttrB) { + return false; + } + + if (tailCallAttrA.getTailCallKind() != tailCallAttrB.getTailCallKind()) { + return false; + } + } else if (auto fastMathAttrA = dyn_cast(lhs)) { + auto fastMathAttrB = dyn_cast(rhs); + if (!fastMathAttrB) { + return false; + } + + if (fastMathAttrA.getValue() != fastMathAttrB.getValue()) { + return false; + } + } else if (auto cconvAttrA = dyn_cast(lhs)) { + auto cconvAttrB = dyn_cast(rhs); + if (!cconvAttrB) { + return false; + } + + if (cconvAttrA.getCallingConv() != cconvAttrB.getCallingConv()) { + return false; + } + } else if (auto modFlagAttrA = dyn_cast(lhs)) { + auto modFlagAttrB = dyn_cast(rhs); + if (!modFlagAttrB) { + return false; + } + + if (modFlagAttrA.getBehavior() != modFlagAttrB.getBehavior() || + modFlagAttrA.getKey() != modFlagAttrB.getKey() || + modFlagAttrA.getValue() != modFlagAttrB.getValue()) { + return false; + } } return true; diff --git a/mlir/lib/Support/Passes.cpp b/mlir/lib/Support/Passes.cpp index 70a5a499d1..bbd56a68c7 100644 --- a/mlir/lib/Support/Passes.cpp +++ b/mlir/lib/Support/Passes.cpp @@ -28,7 +28,7 @@ static void addSimplificationPasses(PassManager& pm) { pm.addPass(createCSEPass()); } -static LogicalResult +LogicalResult runWithPassManager(ModuleOp module, const function_ref populatePasses, const StringRef errorMessage) { @@ -53,10 +53,11 @@ void populateQCOCleanupPipeline(PassManager& pm) { pm.addPass(createRemoveDeadValuesPass()); } -void populateQIRCleanupPipeline(PassManager& pm) { +void populateQIRCleanupPipeline(PassManager& pm, bool useAdaptive) { addSimplificationPasses(pm); pm.addPass(qir::createQIRCleanupPass()); pm.addPass(createRemoveDeadValuesPass()); + pm.addPass(qir::createQIRSetAttributesAndMetadata({useAdaptive})); } [[nodiscard]] LogicalResult runQCCleanupPipeline(ModuleOp module) { @@ -69,7 +70,10 @@ void populateQIRCleanupPipeline(PassManager& pm) { "Failed to run QCO cleanup pipeline."); } -[[nodiscard]] LogicalResult runQIRCleanupPipeline(ModuleOp module) { - return runWithPassManager(module, populateQIRCleanupPipeline, - "Failed to run QIR cleanup pipeline."); +[[nodiscard]] LogicalResult runQIRCleanupPipeline(ModuleOp module, + bool useAdaptive) { + return runWithPassManager( + module, + [&](PassManager& pm) { populateQIRCleanupPipeline(pm, useAdaptive); }, + "Failed to run QIR cleanup pipeline."); } diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 191074ddac..ddc3e4ce4d 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -108,7 +108,7 @@ class CompilerPipelineTest auto module = mlir::qir::QIRProgramBuilder::build( context.get(), builder.fn, mlir::qir::QIRProgramBuilder::Profile::Adaptive); - EXPECT_TRUE(runQIRCleanupPipeline(module.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(module.get(), true).succeeded()); return module; } diff --git a/mlir/unittests/Conversion/QCToQIR/QCToQIRAdaptive/test_qc_to_qir_adaptive.cpp b/mlir/unittests/Conversion/QCToQIR/QCToQIRAdaptive/test_qc_to_qir_adaptive.cpp index 1a1f5a3639..509f1ff321 100644 --- a/mlir/unittests/Conversion/QCToQIR/QCToQIRAdaptive/test_qc_to_qir_adaptive.cpp +++ b/mlir/unittests/Conversion/QCToQIR/QCToQIRAdaptive/test_qc_to_qir_adaptive.cpp @@ -100,7 +100,7 @@ TEST_P(QCToQIRAdaptiveTest, ProgramEquivalence) { printer.record(program.get(), "Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get(), true).succeeded()); printer.record(program.get(), "Canonicalized Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -111,7 +111,7 @@ TEST_P(QCToQIRAdaptiveTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get(), true).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Conversion/QCToQIR/QCToQIRBase/test_qc_to_qir_base.cpp b/mlir/unittests/Conversion/QCToQIR/QCToQIRBase/test_qc_to_qir_base.cpp index 409d1d70de..b637247132 100644 --- a/mlir/unittests/Conversion/QCToQIR/QCToQIRBase/test_qc_to_qir_base.cpp +++ b/mlir/unittests/Conversion/QCToQIR/QCToQIRBase/test_qc_to_qir_base.cpp @@ -98,7 +98,7 @@ TEST_P(QCToQIRBaseTest, ProgramEquivalence) { printer.record(program.get(), "Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get(), false).succeeded()); printer.record(program.get(), "Canonicalized Converted QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -109,7 +109,7 @@ TEST_P(QCToQIRBaseTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get(), false).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); diff --git a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp index a753acafd5..765b249a41 100644 --- a/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp +++ b/mlir/unittests/Dialect/QIR/IR/test_qir_ir.cpp @@ -72,7 +72,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(program.get(), "Original QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(program.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(program.get(), true).succeeded()); printer.record(program.get(), "Canonicalized QIR IR" + name); EXPECT_TRUE(verify(*program).succeeded()); @@ -82,7 +82,7 @@ TEST_P(QIRTest, ProgramEquivalence) { printer.record(reference.get(), "Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded()); - EXPECT_TRUE(runQIRCleanupPipeline(reference.get()).succeeded()); + EXPECT_TRUE(runQIRCleanupPipeline(reference.get(), true).succeeded()); printer.record(reference.get(), "Canonicalized Reference QIR IR" + name); EXPECT_TRUE(verify(*reference).succeeded());