diff --git a/clang-tools-extra/clang-tidy/.clang-format b/clang-tools-extra/clang-tidy/.clang-format
index 5b5066116bbaa..e97ba0573dd1e 100644
--- a/clang-tools-extra/clang-tidy/.clang-format
+++ b/clang-tools-extra/clang-tidy/.clang-format
@@ -1,3 +1,4 @@
 BasedOnStyle: LLVM
 QualifierAlignment: Left
 LineEnding: LF
+InsertNewlineAtEOF: true
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 5f32abca70baa..567c79a27c07b 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -226,6 +226,7 @@ struct MissingFeatures {
   static bool cleanupAppendInsts() { return false; }
   static bool cleanupBranchThrough() { return false; }
   static bool cleanupIndexAndBIAdjustment() { return false; }
+  static bool cleanupWithPreservedValues() { return false; }
   static bool cleanupsToDeactivate() { return false; }
   static bool constEmitterAggILE() { return false; }
   static bool constEmitterArrayILE() { return false; }
diff --git a/clang/lib/AST/ByteCode/Pointer.h b/clang/lib/AST/ByteCode/Pointer.h
index cd738ce8b2a3e..6efec48df71cb 100644
--- a/clang/lib/AST/ByteCode/Pointer.h
+++ b/clang/lib/AST/ByteCode/Pointer.h
@@ -830,6 +830,9 @@ class Pointer {
 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Pointer &P) {
   P.print(OS);
+  OS << ' ';
+  if (const Descriptor *D = P.getFieldDesc())
+    D->dump(OS);
   return OS;
 }
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
index 885a32cf16862..f1be14222434f 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -921,6 +921,13 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
   case Expr::CXXOperatorCallExprClass:
   case Expr::UserDefinedLiteralClass:
     return emitCallExprLValue(cast<CallExpr>(e));
+  case Expr::ExprWithCleanupsClass: {
+    const auto *cleanups = cast<ExprWithCleanups>(e);
+    RunCleanupsScope scope(*this);
+    LValue lv = emitLValue(cleanups->getSubExpr());
+    assert(!cir::MissingFeatures::cleanupWithPreservedValues());
+    return lv;
+  }
   case Expr::ParenExprClass:
     return emitLValue(cast<ParenExpr>(e)->getSubExpr());
   case Expr::GenericSelectionExprClass:
diff --git a/clang/lib/Format/FormatToken.h b/clang/lib/Format/FormatToken.h
index d833130a538f1..56abd702aaafe 100644
--- a/clang/lib/Format/FormatToken.h
+++ b/clang/lib/Format/FormatToken.h
@@ -1170,6 +1170,7 @@ struct AdditionalKeywords {
     kw_checker = &IdentTable.get("checker");
     kw_clocking = &IdentTable.get("clocking");
     kw_constraint = &IdentTable.get("constraint");
+    kw_context = &IdentTable.get("context");
     kw_cover = &IdentTable.get("cover");
     kw_covergroup = &IdentTable.get("covergroup");
     kw_coverpoint = &IdentTable.get("coverpoint");
@@ -1325,50 +1326,138 @@ struct AdditionalKeywords {
     // Some keywords are not included here because they don't need special
     // treatment like `showcancelled` or they should be treated as identifiers
     // like `int` and `logic`.
- VerilogExtraKeywords = std::unordered_set( - {kw_always, kw_always_comb, kw_always_ff, - kw_always_latch, kw_assert, kw_assign, - kw_assume, kw_automatic, kw_before, - kw_begin, kw_bins, kw_binsof, - kw_casex, kw_casez, kw_celldefine, - kw_checker, kw_clocking, kw_constraint, - kw_cover, kw_covergroup, kw_coverpoint, - kw_disable, kw_dist, kw_edge, - kw_end, kw_endcase, kw_endchecker, - kw_endclass, kw_endclocking, kw_endfunction, - kw_endgenerate, kw_endgroup, kw_endinterface, - kw_endmodule, kw_endpackage, kw_endprimitive, - kw_endprogram, kw_endproperty, kw_endsequence, - kw_endspecify, kw_endtable, kw_endtask, - kw_extends, kw_final, kw_foreach, - kw_forever, kw_fork, kw_function, - kw_generate, kw_highz0, kw_highz1, - kw_iff, kw_ifnone, kw_ignore_bins, - kw_illegal_bins, kw_implements, kw_import, - kw_initial, kw_inout, kw_input, - kw_inside, kw_interconnect, kw_interface, - kw_intersect, kw_join, kw_join_any, - kw_join_none, kw_large, kw_let, - kw_local, kw_localparam, kw_macromodule, - kw_matches, kw_medium, kw_negedge, - kw_output, kw_package, kw_packed, - kw_parameter, kw_posedge, kw_primitive, - kw_priority, kw_program, kw_property, - kw_pull0, kw_pull1, kw_pure, - kw_rand, kw_randc, kw_randcase, - kw_randsequence, kw_ref, kw_repeat, - kw_sample, kw_scalared, kw_sequence, - kw_small, kw_soft, kw_solve, - kw_specify, kw_specparam, kw_strong0, - kw_strong1, kw_supply0, kw_supply1, - kw_table, kw_tagged, kw_task, - kw_tri, kw_tri0, kw_tri1, - kw_triand, kw_trior, kw_trireg, - kw_unique, kw_unique0, kw_uwire, - kw_var, kw_vectored, kw_wait, - kw_wand, kw_weak0, kw_weak1, - kw_wildcard, kw_wire, kw_with, - kw_wor, kw_verilogHash, kw_verilogHashHash}); + VerilogExtraKeywords = + std::unordered_set({kw_always, + kw_always_comb, + kw_always_ff, + kw_always_latch, + kw_assert, + kw_assign, + kw_assume, + kw_automatic, + kw_before, + kw_begin, + kw_bins, + kw_binsof, + kw_casex, + kw_casez, + kw_celldefine, + kw_checker, + kw_clocking, + kw_constraint, + kw_context, + kw_cover, + kw_covergroup, + kw_coverpoint, + kw_disable, + kw_dist, + kw_edge, + kw_end, + kw_endcase, + kw_endchecker, + kw_endclass, + kw_endclocking, + kw_endfunction, + kw_endgenerate, + kw_endgroup, + kw_endinterface, + kw_endmodule, + kw_endpackage, + kw_endprimitive, + kw_endprogram, + kw_endproperty, + kw_endsequence, + kw_endspecify, + kw_endtable, + kw_endtask, + kw_extends, + kw_final, + kw_foreach, + kw_forever, + kw_fork, + kw_function, + kw_generate, + kw_highz0, + kw_highz1, + kw_iff, + kw_ifnone, + kw_ignore_bins, + kw_illegal_bins, + kw_implements, + kw_import, + kw_initial, + kw_inout, + kw_input, + kw_inside, + kw_interconnect, + kw_interface, + kw_intersect, + kw_join, + kw_join_any, + kw_join_none, + kw_large, + kw_let, + kw_local, + kw_localparam, + kw_macromodule, + kw_matches, + kw_medium, + kw_module, + kw_negedge, + kw_output, + kw_package, + kw_packed, + kw_parameter, + kw_posedge, + kw_primitive, + kw_priority, + kw_program, + kw_property, + kw_pull0, + kw_pull1, + kw_pure, + kw_rand, + kw_randc, + kw_randcase, + kw_randsequence, + kw_ref, + kw_repeat, + kw_sample, + kw_scalared, + kw_sequence, + kw_small, + kw_soft, + kw_solve, + kw_specify, + kw_specparam, + kw_strong0, + kw_strong1, + kw_supply0, + kw_supply1, + kw_table, + kw_tagged, + kw_task, + kw_tri, + kw_tri0, + kw_tri1, + kw_triand, + kw_trior, + kw_trireg, + kw_unique, + kw_unique0, + kw_uwire, + kw_var, + kw_vectored, + kw_wait, + kw_wand, + kw_weak0, + kw_weak1, + kw_wildcard, + kw_wire, + kw_with, + kw_wor, + kw_verilogHash, + 
kw_verilogHashHash}); TableGenExtraKeywords = std::unordered_set({ kw_assert, @@ -1516,6 +1605,7 @@ struct AdditionalKeywords { IdentifierInfo *kw_checker; IdentifierInfo *kw_clocking; IdentifierInfo *kw_constraint; + IdentifierInfo *kw_context; IdentifierInfo *kw_cover; IdentifierInfo *kw_covergroup; IdentifierInfo *kw_coverpoint; @@ -1800,11 +1890,13 @@ struct AdditionalKeywords { case tok::kw_continue: case tok::kw_default: case tok::kw_do: - case tok::kw_extern: case tok::kw_else: case tok::kw_enum: + case tok::kw_export: + case tok::kw_extern: case tok::kw_for: case tok::kw_if: + case tok::kw_import: case tok::kw_restrict: case tok::kw_signed: case tok::kw_static: diff --git a/clang/lib/Format/UnwrappedLineParser.cpp b/clang/lib/Format/UnwrappedLineParser.cpp index 8b7dd02d548af..50edca43ebb92 100644 --- a/clang/lib/Format/UnwrappedLineParser.cpp +++ b/clang/lib/Format/UnwrappedLineParser.cpp @@ -1592,15 +1592,14 @@ void UnwrappedLineParser::parseStructuralElement( parseTryCatch(); return; case tok::kw_extern: - nextToken(); if (Style.isVerilog()) { - // In Verilog and extern module declaration looks like a start of module. + // In Verilog an extern module declaration looks like a start of module. // But there is no body and endmodule. So we handle it separately. - if (Keywords.isVerilogHierarchy(*FormatTok)) { - parseVerilogHierarchyHeader(); - return; - } - } else if (FormatTok->is(tok::string_literal)) { + parseVerilogExtern(); + return; + } + nextToken(); + if (FormatTok->is(tok::string_literal)) { nextToken(); if (FormatTok->is(tok::l_brace)) { if (Style.BraceWrapping.AfterExternBlock) @@ -1625,6 +1624,10 @@ void UnwrappedLineParser::parseStructuralElement( parseJavaScriptEs6ImportExport(); return; } + if (Style.isVerilog()) { + parseVerilogExtern(); + return; + } if (IsCpp) { nextToken(); if (FormatTok->is(tok::kw_namespace)) { @@ -1673,6 +1676,10 @@ void UnwrappedLineParser::parseStructuralElement( addUnwrappedLine(); return; } + if (Style.isVerilog()) { + parseVerilogExtern(); + return; + } if (IsCpp && parseModuleImport()) return; } @@ -4559,6 +4566,23 @@ void UnwrappedLineParser::parseVerilogCaseLabel() { Line->Level = OrigLevel; } +void UnwrappedLineParser::parseVerilogExtern() { + assert( + FormatTok->isOneOf(tok::kw_extern, tok::kw_export, Keywords.kw_import)); + nextToken(); + // "DPI-C" + if (FormatTok->is(tok::string_literal)) + nextToken(); + if (FormatTok->isOneOf(Keywords.kw_context, Keywords.kw_pure)) + nextToken(); + if (Keywords.isVerilogIdentifier(*FormatTok)) + nextToken(); + if (FormatTok->is(tok::equal)) + nextToken(); + if (Keywords.isVerilogHierarchy(*FormatTok)) + parseVerilogHierarchyHeader(); +} + bool UnwrappedLineParser::containsExpansion(const UnwrappedLine &Line) const { for (const auto &N : Line.Tokens) { if (N.Tok->MacroCtx) diff --git a/clang/lib/Format/UnwrappedLineParser.h b/clang/lib/Format/UnwrappedLineParser.h index 8b8ad84896f1a..0161a5063ad40 100644 --- a/clang/lib/Format/UnwrappedLineParser.h +++ b/clang/lib/Format/UnwrappedLineParser.h @@ -205,6 +205,8 @@ class UnwrappedLineParser { unsigned parseVerilogHierarchyHeader(); void parseVerilogTable(); void parseVerilogCaseLabel(); + // For import, export, and extern. 
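+  // Consumes the optional "DPI-C" string literal, an optional `context` or
+  // `pure` qualifier, an optional identifier and `=`, and then defers to
+  // parseVerilogHierarchyHeader() for any hierarchy keyword that follows.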
+ void parseVerilogExtern(); std::optional, 1>> parseMacroCall(); diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp index a95796924311b..54b302e829e1f 100644 --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -5322,6 +5322,7 @@ void CompilerInvocationBase::visitPathsImpl( RETURN_IF(Input.File); } + // TODO: Also report output files such as FrontendOpts.OutputFile; RETURN_IF(FrontendOpts.CodeCompletionAt.FileName); RETURN_IF_MANY(FrontendOpts.ModuleMapFiles); RETURN_IF_MANY(FrontendOpts.ModuleFiles); diff --git a/clang/test/CIR/CodeGen/temporary-materialization.cpp b/clang/test/CIR/CodeGen/temporary-materialization.cpp new file mode 100644 index 0000000000000..b936ddfe20bf5 --- /dev/null +++ b/clang/test/CIR/CodeGen/temporary-materialization.cpp @@ -0,0 +1,86 @@ +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG + +int make_int(); + +int test() { + const int &x = make_int(); + return x; +} + +// CIR: cir.func {{.*}} @_Z4testv() +// CIR: %[[TEMP_SLOT:.*]] = cir.alloca !s32i, !cir.ptr, ["ref.tmp0", init] +// CIR-NEXT: %[[X:.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["x", init, const] +// CIR-NEXT: %[[TEMP_VALUE:.*]] = cir.call @_Z8make_intv() : () -> !s32i +// CIR-NEXT: cir.store{{.*}} %[[TEMP_VALUE]], %[[TEMP_SLOT]] +// CIR-NEXT: cir.store{{.*}} %[[TEMP_SLOT]], %[[X]] + +// LLVM: define {{.*}} i32 @_Z4testv() +// LLVM: %[[RETVAL:.*]] = alloca i32 +// LLVM: %[[TEMP_SLOT:.*]] = alloca i32 +// LLVM: %[[X:.*]] = alloca ptr +// LLVM: %[[TEMP_VALUE:.*]] = call i32 @_Z8make_intv() +// LLVM: store i32 %[[TEMP_VALUE]], ptr %[[TEMP_SLOT]] +// LLVM: store ptr %[[TEMP_SLOT]], ptr %[[X]] + +// OGCG: define {{.*}} i32 @_Z4testv() +// OGCG: %[[X:.*]] = alloca ptr +// OGCG: %[[TEMP_SLOT:.*]] = alloca i32 +// OGCG: %[[TEMP_VALUE:.*]] = call noundef i32 @_Z8make_intv() +// OGCG: store i32 %[[TEMP_VALUE]], ptr %[[TEMP_SLOT]] +// OGCG: store ptr %[[TEMP_SLOT]], ptr %[[X]] + +int test_scoped() { + int x = make_int(); + { + const int &y = make_int(); + x = y; + } + return x; +} + +// CIR: cir.func {{.*}} @_Z11test_scopedv() +// CIR: %[[X:.*]] = cir.alloca !s32i, !cir.ptr, ["x", init] +// CIR: cir.scope { +// CIR-NEXT: %[[TEMP_SLOT:.*]] = cir.alloca !s32i, !cir.ptr, ["ref.tmp0", init] +// CIR-NEXT: %[[Y_ADDR:.*]] = cir.alloca !cir.ptr, !cir.ptr>, ["y", init, const] +// CIR-NEXT: %[[TEMP_VALUE:.*]] = cir.call @_Z8make_intv() : () -> !s32i +// CIR-NEXT: cir.store{{.*}} %[[TEMP_VALUE]], %[[TEMP_SLOT]] : !s32i, !cir.ptr +// CIR-NEXT: cir.store{{.*}} %[[TEMP_SLOT]], %[[Y_ADDR]] : !cir.ptr, !cir.ptr> +// CIR-NEXT: %[[Y_REF:.*]] = cir.load %[[Y_ADDR]] : !cir.ptr>, !cir.ptr +// CIR-NEXT: %[[Y_VALUE:.*]] = cir.load{{.*}} %[[Y_REF]] : !cir.ptr, !s32i +// CIR-NEXT: cir.store{{.*}} %[[Y_VALUE]], %[[X]] : !s32i, !cir.ptr +// CIR-NEXT: } + +// LLVM: define {{.*}} i32 @_Z11test_scopedv() +// LLVM: %[[TEMP_SLOT:.*]] = alloca i32 +// LLVM: %[[Y_ADDR:.*]] = alloca ptr +// LLVM: %[[RETVAL:.*]] = alloca i32 +// LLVM: %[[X:.*]] = alloca i32 +// LLVM: %[[TEMP_VALUE1:.*]] = call i32 @_Z8make_intv() +// LLVM: store i32 
%[[TEMP_VALUE1]], ptr %[[X]] +// LLVM: br label %[[SCOPE_LABEL:.*]] +// LLVM: [[SCOPE_LABEL]]: +// LLVM: %[[TEMP_VALUE2:.*]] = call i32 @_Z8make_intv() +// LLVM: store i32 %[[TEMP_VALUE2]], ptr %[[TEMP_SLOT]] +// LLVM: store ptr %[[TEMP_SLOT]], ptr %[[Y_ADDR]] +// LLVM: %[[Y_REF:.*]] = load ptr, ptr %[[Y_ADDR]] +// LLVM: %[[Y_VALUE:.*]] = load i32, ptr %[[Y_REF]] +// LLVM: store i32 %[[Y_VALUE]], ptr %[[X]] + +// OGCG: define {{.*}} i32 @_Z11test_scopedv() +// OGCG: %[[X:.*]] = alloca i32 +// OGCG: %[[Y_ADDR:.*]] = alloca ptr +// OGCG: %[[TEMP_SLOT:.*]] = alloca i32 +// OGCG: %[[TEMP_VALUE1:.*]] = call noundef i32 @_Z8make_intv() +// OGCG: store i32 %[[TEMP_VALUE1]], ptr %[[X]] +// OGCG: %[[TEMP_VALUE2:.*]] = call noundef i32 @_Z8make_intv() +// OGCG: store i32 %[[TEMP_VALUE2]], ptr %[[TEMP_SLOT]] +// OGCG: store ptr %[[TEMP_SLOT]], ptr %[[Y_ADDR]] +// OGCG: %[[Y_REF:.*]] = load ptr, ptr %[[Y_ADDR]] +// OGCG: %[[Y_VALUE:.*]] = load i32, ptr %[[Y_REF]] +// OGCG: store i32 %[[Y_VALUE]], ptr %[[X]] diff --git a/clang/tools/offload-arch/AMDGPUArchByHIP.cpp b/clang/tools/offload-arch/AMDGPUArchByHIP.cpp index ff39a85d15628..d7f6d79b135df 100644 --- a/clang/tools/offload-arch/AMDGPUArchByHIP.cpp +++ b/clang/tools/offload-arch/AMDGPUArchByHIP.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ConvertUTF.h" #include "llvm/Support/DynamicLibrary.h" @@ -32,22 +33,54 @@ using namespace llvm; -typedef struct { +// R0600 struct layout (HIP 6.x+) +typedef struct alignas(8) { + char padding[1160]; + char gcnArchName[256]; + char padding2[56]; +} hipDeviceProp_tR0600; + +// R0000 struct layout (legacy) +typedef struct alignas(8) { char padding[396]; char gcnArchName[256]; char padding2[1024]; -} hipDeviceProp_t; +} hipDeviceProp_tR0000; typedef enum { hipSuccess = 0, } hipError_t; typedef hipError_t (*hipGetDeviceCount_t)(int *); -typedef hipError_t (*hipDeviceGet_t)(int *, int); -typedef hipError_t (*hipGetDeviceProperties_t)(hipDeviceProp_t *, int); +typedef hipError_t (*hipGetDevicePropertiesR0600_t)(hipDeviceProp_tR0600 *, + int); +typedef hipError_t (*hipGetDevicePropertiesR0000_t)(hipDeviceProp_tR0000 *, + int); +typedef hipError_t (*hipGetDeviceProperties_t)(hipDeviceProp_tR0000 *, int); +typedef hipError_t (*hipRuntimeGetVersion_t)(int *); +typedef const char *(*hipGetErrorString_t)(hipError_t); extern cl::opt Verbose; +cl::OptionCategory AMDGPUArchByHIPCategory("amdgpu-arch (HIP) options"); + +enum class HipApiVersion { + Auto, // Automatic fallback (R0600 -> R0000 -> unversioned) + R0600, // Force R0600 API (HIP 6.x+) + R0000, // Force R0000 API (legacy HIP) + Unversioned // Force unversioned API (very old HIP) +}; + +static cl::opt HipApi( + "hip-api-version", cl::desc("Select HIP API version for device properties"), + cl::values(clEnumValN(HipApiVersion::Auto, "auto", + "Auto-detect (R0600 -> R0000 -> unversioned)"), + clEnumValN(HipApiVersion::R0600, "r0600", "Force R0600 API"), + clEnumValN(HipApiVersion::R0000, "r0000", "Force R0000 API"), + clEnumValN(HipApiVersion::Unversioned, "unversioned", + "Force unversioned API")), + cl::init(HipApiVersion::Auto), cl::cat(AMDGPUArchByHIPCategory)); + #ifdef _WIN32 static std::vector getSearchPaths() { std::vector Paths; @@ -177,6 +210,9 @@ int printGPUsByHIP() { return 1; } + if (Verbose) + outs() << "Successfully loaded HIP runtime library\n"; + #define 
DYNAMIC_INIT_HIP(SYMBOL) \ { \ void *SymbolPtr = DynlibHandle->getAddressOfSymbol(#SYMBOL); \ @@ -184,42 +220,153 @@ int printGPUsByHIP() { llvm::errs() << "Failed to find symbol " << #SYMBOL << '\n'; \ return 1; \ } \ + if (Verbose) \ + outs() << "Found symbol: " << #SYMBOL << '\n'; \ SYMBOL = reinterpret_cast(SymbolPtr); \ } hipGetDeviceCount_t hipGetDeviceCount; - hipDeviceGet_t hipDeviceGet; - hipGetDeviceProperties_t hipGetDeviceProperties; + hipRuntimeGetVersion_t hipRuntimeGetVersion = nullptr; + hipGetDevicePropertiesR0600_t hipGetDevicePropertiesR0600 = nullptr; + hipGetDevicePropertiesR0000_t hipGetDevicePropertiesR0000 = nullptr; + hipGetDeviceProperties_t hipGetDeviceProperties = nullptr; + hipGetErrorString_t hipGetErrorString = nullptr; DYNAMIC_INIT_HIP(hipGetDeviceCount); - DYNAMIC_INIT_HIP(hipDeviceGet); - DYNAMIC_INIT_HIP(hipGetDeviceProperties); #undef DYNAMIC_INIT_HIP - int deviceCount; - hipError_t err = hipGetDeviceCount(&deviceCount); - if (err != hipSuccess) { - llvm::errs() << "Failed to get device count\n"; + auto LoadSymbol = [&](const char *Name, auto &FuncPtr, + const char *Desc = "") { + void *Sym = DynlibHandle->getAddressOfSymbol(Name); + if (Sym) { + FuncPtr = reinterpret_cast(Sym); + if (Verbose) + outs() << "Found symbol: " << Name << (Desc[0] ? " " : "") << Desc + << '\n'; + return true; + } + return false; + }; + + LoadSymbol("hipGetErrorString", hipGetErrorString); + + if (LoadSymbol("hipRuntimeGetVersion", hipRuntimeGetVersion)) { + int RuntimeVersion = 0; + if (hipRuntimeGetVersion(&RuntimeVersion) == hipSuccess) { + int Major = RuntimeVersion / 10000000; + int Minor = (RuntimeVersion / 100000) % 100; + int Patch = RuntimeVersion % 100000; + if (Verbose) + outs() << "HIP Runtime Version: " << Major << "." << Minor << "." 
+ << Patch << '\n'; + } + } + + LoadSymbol("hipGetDevicePropertiesR0600", hipGetDevicePropertiesR0600, + "(HIP 6.x+ API)"); + LoadSymbol("hipGetDevicePropertiesR0000", hipGetDevicePropertiesR0000, + "(legacy API)"); + if (!hipGetDevicePropertiesR0600 && !hipGetDevicePropertiesR0000) + LoadSymbol("hipGetDeviceProperties", hipGetDeviceProperties, + "(unversioned legacy API)"); + + int DeviceCount; + if (Verbose) + outs() << "Calling hipGetDeviceCount...\n"; + hipError_t Err = hipGetDeviceCount(&DeviceCount); + if (Err != hipSuccess) { + llvm::errs() << "Failed to get device count"; + if (hipGetErrorString) { + llvm::errs() << ": " << hipGetErrorString(Err); + } + llvm::errs() << " (error code: " << Err << ")\n"; return 1; } - for (int i = 0; i < deviceCount; ++i) { - int deviceId; - err = hipDeviceGet(&deviceId, i); - if (err != hipSuccess) { - llvm::errs() << "Failed to get device id for ordinal " << i << '\n'; - return 1; + if (Verbose) + outs() << "Found " << DeviceCount << " device(s)\n"; + + auto TryGetProperties = [&](auto *ApiFunc, auto *DummyProp, const char *Name, + int DeviceId) -> std::string { + if (!ApiFunc) + return ""; + + if (Verbose) + outs() << "Using " << Name << "...\n"; + + using PropType = std::remove_pointer_t; + PropType Prop; + hipError_t Err = ApiFunc(&Prop, DeviceId); + + if (Err == hipSuccess) { + if (Verbose) { + outs() << Name << " struct: sizeof = " << sizeof(PropType) + << " bytes, offsetof(gcnArchName) = " + << offsetof(PropType, gcnArchName) << " bytes\n"; + } + return Prop.gcnArchName; } - hipDeviceProp_t prop; - err = hipGetDeviceProperties(&prop, deviceId); - if (err != hipSuccess) { - llvm::errs() << "Failed to get device properties for device " << deviceId - << '\n'; + if (Verbose) + llvm::errs() << Name << " failed (error code: " << Err << ")\n"; + return ""; + }; + + for (auto I : llvm::seq(DeviceCount)) { + if (Verbose) + outs() << "Processing device " << I << "...\n"; + + std::string ArchName; + auto TryR0600 = [&](int Dev) -> bool { + if (!hipGetDevicePropertiesR0600) + return false; + ArchName = TryGetProperties(hipGetDevicePropertiesR0600, + (hipDeviceProp_tR0600 *)nullptr, + "R0600 API (HIP 6.x+)", Dev); + return !ArchName.empty(); + }; + auto TryR0000 = [&](int Dev) -> bool { + if (!hipGetDevicePropertiesR0000) + return false; + ArchName = TryGetProperties(hipGetDevicePropertiesR0000, + (hipDeviceProp_tR0000 *)nullptr, + "R0000 API (legacy HIP)", Dev); + return !ArchName.empty(); + }; + auto TryUnversioned = [&](int Dev) -> bool { + if (!hipGetDeviceProperties) + return false; + ArchName = TryGetProperties(hipGetDeviceProperties, + (hipDeviceProp_tR0000 *)nullptr, + "unversioned API (very old HIP)", Dev); + return !ArchName.empty(); + }; + + [[maybe_unused]] bool OK; + switch (HipApi) { + case HipApiVersion::Auto: + OK = TryR0600(I) || TryR0000(I) || TryUnversioned(I); + break; + case HipApiVersion::R0600: + OK = TryR0600(I); + break; + case HipApiVersion::R0000: + OK = TryR0000(I); + break; + case HipApiVersion::Unversioned: + OK = TryUnversioned(I); + } + + if (ArchName.empty()) { + llvm::errs() << "Failed to get device properties for device " << I + << " - no APIs available or all failed\n"; return 1; } - llvm::outs() << prop.gcnArchName << '\n'; + + if (Verbose) + outs() << "Device " << I << " arch name: "; + llvm::outs() << ArchName << '\n'; } return 0; diff --git a/clang/tools/offload-arch/OffloadArch.cpp b/clang/tools/offload-arch/OffloadArch.cpp index 3c5131eb7c06c..91493676918cb 100644 --- a/clang/tools/offload-arch/OffloadArch.cpp +++ 
b/clang/tools/offload-arch/OffloadArch.cpp @@ -17,6 +17,8 @@ static cl::opt Help("h", cl::desc("Alias for -help"), cl::Hidden); // Mark all our options with this category. static cl::OptionCategory OffloadArchCategory("offload-arch options"); +extern cl::OptionCategory AMDGPUArchByHIPCategory; + enum VendorName { all, amdgpu, @@ -62,7 +64,7 @@ const std::array>, 3> VendorTable{ {VendorName::intel, printIntel}}}; int main(int argc, char *argv[]) { - cl::HideUnrelatedOptions(OffloadArchCategory); + cl::HideUnrelatedOptions({&OffloadArchCategory, &AMDGPUArchByHIPCategory}); cl::SetVersionPrinter(PrintVersion); cl::ParseCommandLineOptions( diff --git a/clang/unittests/Format/FormatTestVerilog.cpp b/clang/unittests/Format/FormatTestVerilog.cpp index 63e2cadfdd7a1..eee2bbdf551e6 100644 --- a/clang/unittests/Format/FormatTestVerilog.cpp +++ b/clang/unittests/Format/FormatTestVerilog.cpp @@ -676,6 +676,16 @@ TEST_F(FormatTestVerilog, Hierarchy) { " endprogram\n" "endmodule"); // Test that an extern declaration doesn't change the indentation. + verifyFormat("import \"DPI-C\" context MyCFunc = function integer MapID\n" + " (int portID);\n" + "x = x;"); + verifyFormat("export \"DPI-C\" function exported_sv_func;\n" + "x = x;"); + verifyFormat("import \"DPI-C\" function void f1\n" + " (input int i1,\n" + " pair i2,\n" + " output logic [63 : 0] o3);\n" + "x = x;"); verifyFormat("extern module x;\n" "x = x;"); // Test complex headers diff --git a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_mac.cpp b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_mac.cpp index a5ec85ae16460..f40fba6bf7151 100644 --- a/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_mac.cpp +++ b/compiler-rt/lib/sanitizer_common/sanitizer_procmaps_mac.cpp @@ -45,7 +45,6 @@ struct MemoryMappedSegmentData { const char *current_load_cmd_addr; u32 lc_type; uptr base_virt_addr; - uptr addr_mask; }; template @@ -54,12 +53,60 @@ static void NextSectionLoad(LoadedModule *module, MemoryMappedSegmentData *data, const Section *sc = (const Section *)data->current_load_cmd_addr; data->current_load_cmd_addr += sizeof(Section); - uptr sec_start = (sc->addr & data->addr_mask) + data->base_virt_addr; + uptr sec_start = sc->addr + data->base_virt_addr; uptr sec_end = sec_start + sc->size; module->addAddressRange(sec_start, sec_end, /*executable=*/false, isWritable, sc->sectname); } +static bool VerifyMemoryMapping(MemoryMappingLayout* mapping) { + InternalMmapVector modules; + modules.reserve(128); // matches DumpProcessMap + mapping->DumpListOfModules(&modules); + + InternalMmapVector segments; + for (uptr i = 0; i < modules.size(); ++i) { + for (auto& range : modules[i].ranges()) { + segments.push_back(range); + } + } + + // Verify that none of the segments overlap: + // 1. Sort the segments by the start address + // 2. Check that every segment starts after the previous one ends. + Sort(segments.data(), segments.size(), + [](LoadedModule::AddressRange& a, LoadedModule::AddressRange& b) { + return a.beg < b.beg; + }); + + // To avoid spam, we only print the report message once-per-process. 
+  static bool invalid_module_map_reported = false;
+  bool well_formed = true;
+
+  for (size_t i = 1; i < segments.size(); i++) {
+    uptr cur_start = segments[i].beg;
+    uptr prev_end = segments[i - 1].end;
+    if (cur_start < prev_end) {
+      well_formed = false;
+      VReport(2, "Overlapping mappings: %s start = %p, %s end = %p\n",
+              segments[i].name, (void*)cur_start, segments[i - 1].name,
+              (void*)prev_end);
+      if (!invalid_module_map_reported) {
+        Report(
+            "WARN: Invalid dyld module map detected. This is most likely a bug "
+            "in the sanitizer.\n");
+        Report("WARN: Backtraces may be unreliable.\n");
+        invalid_module_map_reported = true;
+      }
+    }
+  }
+
+  for (auto& m : modules) m.clear();
+
+  mapping->Reset();
+  return well_formed;
+}
+
 void MemoryMappedSegment::AddAddressRanges(LoadedModule *module) {
   // Don't iterate over sections when the caller hasn't set up the
   // data pointer, when there are no sections, or when the segment
@@ -85,6 +132,7 @@ void MemoryMappedSegment::AddAddressRanges(LoadedModule *module) {
 
 MemoryMappingLayout::MemoryMappingLayout(bool cache_enabled) {
   Reset();
+  VerifyMemoryMapping(this);
 }
 
 MemoryMappingLayout::~MemoryMappingLayout() {
@@ -190,6 +238,7 @@ typedef struct dyld_shared_cache_dylib_text_info
 extern bool _dyld_get_shared_cache_uuid(uuid_t uuid);
 extern const void *_dyld_get_shared_cache_range(size_t *length);
+extern intptr_t _dyld_get_image_slide(const struct mach_header* mh);
 extern int dyld_shared_cache_iterate_text(
     const uuid_t cacheUuid,
     void (^callback)(const dyld_shared_cache_dylib_text_info *info));
@@ -258,23 +307,21 @@ static bool NextSegmentLoad(MemoryMappedSegment *segment,
   layout_data->current_load_cmd_count--;
   if (((const load_command *)lc)->cmd == kLCSegment) {
     const SegmentCommand* sc = (const SegmentCommand *)lc;
-    uptr base_virt_addr, addr_mask;
-    if (layout_data->current_image == kDyldImageIdx) {
-      base_virt_addr = (uptr)get_dyld_hdr();
-      // vmaddr is masked with 0xfffff because on macOS versions < 10.12,
-      // it contains an absolute address rather than an offset for dyld.
-      // To make matters even more complicated, this absolute address
-      // isn't actually the absolute segment address, but the offset portion
-      // of the address is accurate when combined with the dyld base address,
-      // and the mask will give just this offset.
-      addr_mask = 0xfffff;
-    } else {
+    if (internal_strcmp(sc->segname, "__LINKEDIT") == 0) {
+      // The LINKEDIT sections are for internal linker use, and may alias
+      // with the LINKEDIT section for other modules. (If we included them,
+      // our memory map would contain overlapping sections.)
+      return false;
+    }
+
+    uptr base_virt_addr;
+    if (layout_data->current_image == kDyldImageIdx)
+      base_virt_addr = (uptr)_dyld_get_image_slide(get_dyld_hdr());
+    else
       base_virt_addr =
           (uptr)_dyld_get_image_vmaddr_slide(layout_data->current_image);
-      addr_mask = ~0;
-    }
-    segment->start = (sc->vmaddr & addr_mask) + base_virt_addr;
+    segment->start = sc->vmaddr + base_virt_addr;
     segment->end = segment->start + sc->vmsize;
     // Most callers don't need section information, so only fill this struct
     // when required.
@@ -284,9 +331,9 @@ static bool NextSegmentLoad(MemoryMappedSegment *segment,
         (const char *)lc + sizeof(SegmentCommand);
     seg_data->lc_type = kLCSegment;
     seg_data->base_virt_addr = base_virt_addr;
-    seg_data->addr_mask = addr_mask;
     internal_strncpy(seg_data->name, sc->segname,
                      ARRAY_SIZE(seg_data->name));
+    seg_data->name[ARRAY_SIZE(seg_data->name) - 1] = 0;
  }
 
   // Return the initial protection.
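Aside: VerifyMemoryMapping above is a plain sort-then-scan interval-overlap check. A minimal self-contained sketch of the same idea, using a hypothetical Range type rather than the sanitizer's internal containers:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Range {
  uintptr_t beg, end;  // half-open: [beg, end)
};

// Mirrors the check above: sort by start address, then every range must
// begin at or after the end of its predecessor.
static bool IsWellFormed(std::vector<Range> ranges) {
  std::sort(ranges.begin(), ranges.end(),
            [](const Range &a, const Range &b) { return a.beg < b.beg; });
  for (size_t i = 1; i < ranges.size(); i++)
    if (ranges[i].beg < ranges[i - 1].end)
      return false;  // ranges[i] overlaps ranges[i - 1]
  return true;
}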
@@ -300,6 +347,7 @@ static bool NextSegmentLoad(MemoryMappedSegment *segment,
             ? kDyldPath
             : _dyld_get_image_name(layout_data->current_image);
     internal_strncpy(segment->filename, src, segment->filename_size);
+    segment->filename[segment->filename_size - 1] = 0;
   }
   segment->arch = layout_data->current_arch;
   internal_memcpy(segment->uuid, layout_data->current_uuid, kModuleUUIDSize);
diff --git a/compiler-rt/test/asan/TestCases/Darwin/asan-verify-module-map.cpp b/compiler-rt/test/asan/TestCases/Darwin/asan-verify-module-map.cpp
new file mode 100644
index 0000000000000..15be1cd6754c3
--- /dev/null
+++ b/compiler-rt/test/asan/TestCases/Darwin/asan-verify-module-map.cpp
@@ -0,0 +1,25 @@
+// This test simply checks that the "Invalid dyld module map" warning is not printed
+// in the output of a backtrace.
+
+// RUN: %clangxx_asan -DSHARED_LIB -g %s -dynamiclib -o %t.dylib
+// RUN: %clangxx_asan -O0 -g %s %t.dylib -o %t.executable
+// RUN: %env_asan_opts="print_module_map=2" not %run %t.executable 2>&1 | FileCheck %s -DDYLIB=%{t:stem}.tmp.dylib
+
+// CHECK-NOT: WARN: Invalid dyld module map
+// CHECK-DAG: 0x{{.*}}-0x{{.*}} {{.*}}[[DYLIB]]
+// CHECK-DAG: 0x{{.*}}-0x{{.*}} {{.*}}libsystem
+
+#ifdef SHARED_LIB
+extern "C" void foo(int *a) { *a = 5; }
+#else
+# include <stdlib.h>
+
+extern "C" void foo(int *a);
+
+int main() {
+  int *a = (int *)malloc(sizeof(int));
+  free(a);
+  foo(a);
+  return 0;
+}
+#endif
\ No newline at end of file
diff --git a/compiler-rt/test/dfsan/origin_endianness.c b/compiler-rt/test/dfsan/origin_endianness.c
new file mode 100644
index 0000000000000..a73dcda080e79
--- /dev/null
+++ b/compiler-rt/test/dfsan/origin_endianness.c
@@ -0,0 +1,37 @@
+// RUN: %clang_dfsan -gmlt -mllvm -dfsan-track-origins=1 %s -o %t && \
+// RUN: %run %t >%t.out 2>&1
+// RUN: FileCheck %s < %t.out
+//
+// Test origin tracking is accurate in terms of endianness.
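+//
+// Background: a wide shadow load covers both halves of the value plus two
+// consecutive origin slots, and combineOrigins() picks the second slot iff
+// the least-significant half of the wide shadow is empty but the other half
+// is not. Which slot holds the low half's origin depends on byte order,
+// hence the little-/big-endian split in loadShadowFast().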
+
+#include <sanitizer/dfsan_interface.h>
+
+typedef uint64_t FULL_TYPE;
+typedef uint32_t HALF_TYPE;
+
+__attribute__((noinline)) FULL_TYPE foo(FULL_TYPE a, FULL_TYPE b) {
+  return a + b;
+}
+
+int main(int argc, char *argv[]) {
+  FULL_TYPE a = 1;
+  FULL_TYPE b = 10;
+  dfsan_set_label(4, (HALF_TYPE *)&a, sizeof(HALF_TYPE));
+  FULL_TYPE c = foo(a, b);
+  dfsan_print_origin_trace(&c, NULL);
+  dfsan_print_origin_trace((HALF_TYPE *)&c, NULL);
+}
+
+// CHECK: Taint value 0x4 {{.*}} origin tracking ()
+// CHECK: Origin value: {{.*}}, Taint value was stored to memory at
+// CHECK: #0 {{.*}} in main {{.*}}origin_endianness.c:[[@LINE-7]]
+
+// CHECK: Origin value: {{.*}}, Taint value was created at
+// CHECK: #0 {{.*}} in main {{.*}}origin_endianness.c:[[@LINE-11]]
+
+// CHECK: Taint value 0x4 {{.*}} origin tracking ()
+// CHECK: Origin value: {{.*}}, Taint value was stored to memory at
+// CHECK: #0 {{.*}} in main {{.*}}origin_endianness.c:[[@LINE-14]]
+
+// CHECK: Origin value: {{.*}}, Taint value was created at
+// CHECK: #0 {{.*}} in main {{.*}}origin_endianness.c:[[@LINE-18]]
diff --git a/libcxx/docs/Status/Cxx23Issues.csv b/libcxx/docs/Status/Cxx23Issues.csv
index 5a68b51ec85fb..389e1ad254a74 100644
--- a/libcxx/docs/Status/Cxx23Issues.csv
+++ b/libcxx/docs/Status/Cxx23Issues.csv
@@ -58,7 +58,7 @@
 "`LWG3495 <https://wg21.link/LWG3495>`__","``constexpr launder`` makes pointers to inactive members of unions usable","2021-02 (Virtual)","|Nothing To Do|","","`#104316 <https://github.com/llvm/llvm-project/issues/104316>`__",""
 "`LWG3500 <https://wg21.link/LWG3500>`__","``join_view::iterator::operator->()`` is bogus","2021-02 (Virtual)","|Complete|","14","`#104318 <https://github.com/llvm/llvm-project/issues/104318>`__",""
 "`LWG3502 <https://wg21.link/LWG3502>`__","``elements_view`` should not be allowed to return dangling reference","2021-02 (Virtual)","|Complete|","16","`#104319 <https://github.com/llvm/llvm-project/issues/104319>`__",""
-"`LWG3505 <https://wg21.link/LWG3505>`__","``split_view::outer-iterator::operator++`` misspecified","2021-02 (Virtual)","","","`#104320 <https://github.com/llvm/llvm-project/issues/104320>`__",""
+"`LWG3505 <https://wg21.link/LWG3505>`__","``split_view::outer-iterator::operator++`` misspecified","2021-02 (Virtual)","|Complete|","15","`#104320 <https://github.com/llvm/llvm-project/issues/104320>`__",""
 "","","","","","",""
 "`LWG2774 <https://wg21.link/LWG2774>`__","``std::function`` construction vs assignment","2021-06 (Virtual)","","","`#104321 <https://github.com/llvm/llvm-project/issues/104321>`__",""
 "`LWG2818 <https://wg21.link/LWG2818>`__","``::std::`` everywhere rule needs tweaking","2021-06 (Virtual)","|Nothing To Do|","","`#104322 <https://github.com/llvm/llvm-project/issues/104322>`__",""
diff --git a/libcxx/test/std/ranges/range.adaptors/range.lazy.split/range.lazy.split.outer/increment.pass.cpp b/libcxx/test/std/ranges/range.adaptors/range.lazy.split/range.lazy.split.outer/increment.pass.cpp
index 4d765d71407f5..b557346588306 100644
--- a/libcxx/test/std/ranges/range.adaptors/range.lazy.split/range.lazy.split.outer/increment.pass.cpp
+++ b/libcxx/test/std/ranges/range.adaptors/range.lazy.split/range.lazy.split.outer/increment.pass.cpp
@@ -75,6 +75,56 @@ constexpr bool test() {
     }
   }
 
+  // LWG3505
+  {
+    using namespace std::string_view_literals;
+
+    { // Motivational example
+      auto v = std::views::lazy_split("xxyx"sv, "xy"sv);
+
+      {
+        auto i = v.begin();
+        assert(std::ranges::equal(*i, "x"s));
+
+        decltype(auto) i2 = ++i;
+        static_assert(std::is_lvalue_reference_v<decltype(i2)>);
+        assert(std::ranges::equal(*i2, "x"s));
+      }
+
+      {
+        auto i = v.begin();
+        assert(std::ranges::equal(*i, "x"s));
+
+        decltype(auto) i2 = i++;
+        static_assert(!std::is_reference_v<decltype(i2)>);
+        assert(std::ranges::equal(*i2, "x"s));
+        assert(std::ranges::equal(*i, "x"s));
+      }
+    }
+    {
+      auto v = std::views::lazy_split("zzht"sv, "zh"sv);
+
+      {
+        auto i = v.begin();
+        assert(std::ranges::equal(*i, "z"s));
+
+        decltype(auto) i2 = ++i;
+        static_assert(std::is_lvalue_reference_v<decltype(i2)>);
+        assert(std::ranges::equal(*i2, "t"s));
+      }
+
+      {
+        auto i = v.begin();
+        assert(std::ranges::equal(*i, "z"s));
+
+        decltype(auto) i2 = i++;
+        static_assert(!std::is_reference_v<decltype(i2)>);
+        assert(std::ranges::equal(*i2, "z"s));
+        assert(std::ranges::equal(*i, "t"s));
+      }
+    }
+  }
+
   return true;
 }
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 0679e8dfc3b43..aaf7a921c2981 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -4017,28 +4017,6 @@ bool SIRegisterInfo::isProperlyAlignedRC(const TargetRegisterClass &RC) const {
   return true;
 }
 
-const TargetRegisterClass *
-SIRegisterInfo::getProperlyAlignedRC(const TargetRegisterClass *RC) const {
-  if (!RC || !ST.needsAlignedVGPRs())
-    return RC;
-
-  unsigned Size = getRegSizeInBits(*RC);
-  if (Size <= 32)
-    return RC;
-
-  if (RC == &AMDGPU::VS_64RegClass)
-    return &AMDGPU::VS_64_Align2RegClass;
-
-  if (isVGPRClass(RC))
-    return getAlignedVGPRClassForBitWidth(Size);
-  if (isAGPRClass(RC))
-    return getAlignedAGPRClassForBitWidth(Size);
-  if (isVectorSuperClass(RC))
-    return getAlignedVectorSuperClassForBitWidth(Size);
-
-  return RC;
-}
-
 ArrayRef<MCPhysReg>
 SIRegisterInfo::getAllSGPR128(const MachineFunction &MF) const {
   return ArrayRef(AMDGPU::SGPR_128RegClass.begin(), ST.getMaxNumSGPRs(MF) / 4);
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
index a6af25dfd7d6f..1402291539ff8 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
@@ -439,11 +439,6 @@ class SIRegisterInfo final : public AMDGPUGenRegisterInfo {
   // the subtarget.
   bool isProperlyAlignedRC(const TargetRegisterClass &RC) const;
 
-  // Given \p RC returns corresponding aligned register class if required
-  // by the subtarget.
-  const TargetRegisterClass *
-  getProperlyAlignedRC(const TargetRegisterClass *RC) const;
-
   /// Return all SGPR128 which satisfy the waves per execution unit requirement
   /// of the subtarget.
   ArrayRef<MCPhysReg> getAllSGPR128(const MachineFunction &MF) const;
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 04a97606cb7f8..894a07e6b68c2 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -3958,3 +3958,41 @@ bool HexagonTargetLowering::isMaskAndCmp0FoldingBeneficial(
     return false;
   return Mask->getValue().isPowerOf2();
 }
+
+// Check if the result of the node is only used as a return value, as
+// otherwise we can't perform a tail-call.
+bool HexagonTargetLowering::isUsedByReturnOnly(SDNode *N,
+                                               SDValue &Chain) const {
+  if (N->getNumValues() != 1)
+    return false;
+  if (!N->hasNUsesOfValue(1, 0))
+    return false;
+
+  SDNode *Copy = *N->user_begin();
+
+  if (Copy->getOpcode() == ISD::BITCAST) {
+    return isUsedByReturnOnly(Copy, Chain);
+  }
+
+  if (Copy->getOpcode() != ISD::CopyToReg) {
+    return false;
+  }
+
+  // If the ISD::CopyToReg has a glue operand, we conservatively assume it
+  // isn't safe to perform a tail call.
+  if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() == MVT::Glue)
+    return false;
+
+  // The copy must be used by a HexagonISD::RET_GLUE, and nothing else.
+  bool HasRet = false;
+  for (SDNode *Node : Copy->users()) {
+    if (Node->getOpcode() != HexagonISD::RET_GLUE)
+      return false;
+    HasRet = true;
+  }
+  if (!HasRet)
+    return false;
+
+  Chain = Copy->getOperand(0);
+  return true;
+}
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
index 4ac3e7671592a..f4d2a79051c10 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
@@ -162,6 +162,8 @@ class HexagonTargetLowering : public TargetLowering {
 
   bool isMaskAndCmp0FoldingBeneficial(const Instruction &AndI) const override;
 
+  bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
+
   /// Return true if an FMA operation is faster than a pair of mul and add
   /// instructions. fmuladd intrinsics will be expanded to FMAs when this
   /// method returns true (and FMAs are legal), otherwise fmuladd is
diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index cc53ec2c0f2f3..e984ac46fca4a 100644
--- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -2191,8 +2191,16 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast(
   // and then the entire shadow for the second origin pointer (which will be
   // chosen by combineOrigins() iff the least-significant half of the wide
   // shadow was empty but the other half was not).
-  Value *WideShadowLo = IRB.CreateShl(
-      WideShadow, ConstantInt::get(WideShadowTy, WideShadowBitWidth / 2));
+  Value *WideShadowLo =
+      F->getParent()->getDataLayout().isLittleEndian()
+          ? IRB.CreateShl(
+                WideShadow,
+                ConstantInt::get(WideShadowTy, WideShadowBitWidth / 2))
+          : IRB.CreateAnd(
+                WideShadow,
+                ConstantInt::get(WideShadowTy,
+                                 (1 - (1 << (WideShadowBitWidth / 2)))
+                                     << (WideShadowBitWidth / 2)));
   Shadows.push_back(WideShadow);
   Origins.push_back(DFS.loadNextOrigin(Pos, OriginAlign, &OriginAddr));
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index ceeece41782f4..7f1ac41f2e212 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -2720,34 +2720,55 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   // of elements.
   //
   // For example, suppose we have:
-  //   VectorA: <a1, a2, a3, a4, a5, a6>
-  //   VectorB: <b1, b2, b3, b4, b5, b6>
-  //   ReductionFactor: 3.
+  //   VectorA: <a1, a2, a3, a4, a5, a6>
+  //   VectorB: <b1, b2, b3, b4, b5, b6>
+  //   ReductionFactor: 3
+  //   Shards: 1
   // The output would be:
-  //   <a1|a2|a3, a4|a5|a6, b1|b2|b3, b4|b5|b6>
+  //   <a1|a2|a3, a4|a5|a6, b1|b2|b3, b4|b5|b6>
+  //
+  // If we have:
+  //   VectorA: <a1, a2, a3, a4>
+  //   VectorB: <b1, b2, b3, b4>
+  //   ReductionFactor: 2
+  //   Shards: 2
+  // then a and b each have 2 "shards", resulting in the output being
+  // interleaved:
+  //   <a1|a2, b1|b2, a3|a4, b3|b4>
   //
   // This is convenient for instrumenting horizontal add/sub.
   // For bitwise OR on "vertical" pairs, see maybeHandleSimpleNomemIntrinsic().
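+  //
+  // (Worked example of the mask construction for the second case above:
+  // with NumElems = 4, ReductionFactor = 2, Shards = 2, the i = 0 shuffle
+  // mask is [0, 4, 2, 6] and the i = 1 mask is [1, 5, 3, 7]; OR-ing the two
+  // shuffles of (VectorA, VectorB) yields <a1|a2, b1|b2, a3|a4, b3|b4>.)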
Value *horizontalReduce(IntrinsicInst &I, unsigned ReductionFactor, - Value *VectorA, Value *VectorB) { + unsigned Shards, Value *VectorA, Value *VectorB) { assert(isa(VectorA->getType())); - unsigned TotalNumElems = + unsigned NumElems = cast(VectorA->getType())->getNumElements(); + [[maybe_unused]] unsigned TotalNumElems = NumElems; if (VectorB) { assert(VectorA->getType() == VectorB->getType()); - TotalNumElems = TotalNumElems * 2; + TotalNumElems *= 2; } - assert(TotalNumElems % ReductionFactor == 0); + assert(NumElems % (ReductionFactor * Shards) == 0); Value *Or = nullptr; IRBuilder<> IRB(&I); for (unsigned i = 0; i < ReductionFactor; i++) { SmallVector Mask; - for (unsigned X = 0; X < TotalNumElems; X += ReductionFactor) - Mask.push_back(X + i); + + for (unsigned j = 0; j < Shards; j++) { + unsigned Offset = NumElems / Shards * j; + + for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor) + Mask.push_back(Offset + X + i); + + if (VectorB) { + for (unsigned X = 0; X < NumElems / Shards; X += ReductionFactor) + Mask.push_back(NumElems + Offset + X + i); + } + } Value *Masked; if (VectorB) @@ -2769,7 +2790,7 @@ struct MemorySanitizerVisitor : public InstVisitor { /// /// e.g., <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16>) /// <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8>, <16 x i8>) - void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I) { + void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I, unsigned Shards) { assert(I.arg_size() == 1 || I.arg_size() == 2); assert(I.getType()->isVectorTy()); @@ -2792,8 +2813,8 @@ struct MemorySanitizerVisitor : public InstVisitor { if (I.arg_size() == 2) SecondArgShadow = getShadow(&I, 1); - Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, FirstArgShadow, - SecondArgShadow); + Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, Shards, + FirstArgShadow, SecondArgShadow); OrShadow = CreateShadowCast(IRB, OrShadow, getShadowTy(&I)); @@ -2808,7 +2829,7 @@ struct MemorySanitizerVisitor : public InstVisitor { /// conceptually operates on /// (<4 x i16> [[VAR1]], <4 x i16> [[VAR2]]) /// and can be handled with ReinterpretElemWidth == 16. - void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I, + void handlePairwiseShadowOrIntrinsic(IntrinsicInst &I, unsigned Shards, int ReinterpretElemWidth) { assert(I.arg_size() == 1 || I.arg_size() == 2); @@ -2852,8 +2873,8 @@ struct MemorySanitizerVisitor : public InstVisitor { SecondArgShadow = IRB.CreateBitCast(SecondArgShadow, ReinterpretShadowTy); } - Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, FirstArgShadow, - SecondArgShadow); + Value *OrShadow = horizontalReduce(I, /*ReductionFactor=*/2, Shards, + FirstArgShadow, SecondArgShadow); OrShadow = CreateShadowCast(IRB, OrShadow, getShadowTy(&I)); @@ -5925,15 +5946,20 @@ struct MemorySanitizerVisitor : public InstVisitor { /*ZeroPurifies=*/true, /*EltSizeInBits=*/16); break; - // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single - // Precision - // <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128 - // (<4 x float>, <8 x bfloat>, <8 x bfloat>) - // <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256 - // (<8 x float>, <16 x bfloat>, <16 x bfloat>) - // <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512 - // (<16 x float>, <32 x bfloat>, <32 x bfloat>) - // handleVectorPmaddIntrinsic() currently only handles integer types. 
+ // Dot Product of BF16 Pairs Accumulated Into Packed Single + // Precision + // <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128 + // (<4 x float>, <8 x bfloat>, <8 x bfloat>) + // <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256 + // (<8 x float>, <16 x bfloat>, <16 x bfloat>) + // <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512 + // (<16 x float>, <32 x bfloat>, <32 x bfloat>) + case Intrinsic::x86_avx512bf16_dpbf16ps_128: + case Intrinsic::x86_avx512bf16_dpbf16ps_256: + case Intrinsic::x86_avx512bf16_dpbf16ps_512: + handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, + /*ZeroPurifies=*/false); + break; case Intrinsic::x86_sse_cmp_ss: case Intrinsic::x86_sse2_cmp_sd: @@ -6031,48 +6057,66 @@ struct MemorySanitizerVisitor : public InstVisitor { // Packed Horizontal Add/Subtract case Intrinsic::x86_ssse3_phadd_w: case Intrinsic::x86_ssse3_phadd_w_128: - case Intrinsic::x86_avx2_phadd_w: case Intrinsic::x86_ssse3_phsub_w: case Intrinsic::x86_ssse3_phsub_w_128: - case Intrinsic::x86_avx2_phsub_w: { - handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16); + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/16); + break; + + case Intrinsic::x86_avx2_phadd_w: + case Intrinsic::x86_avx2_phsub_w: + // TODO: Shards = 2 + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/16); break; - } // Packed Horizontal Add/Subtract case Intrinsic::x86_ssse3_phadd_d: case Intrinsic::x86_ssse3_phadd_d_128: - case Intrinsic::x86_avx2_phadd_d: case Intrinsic::x86_ssse3_phsub_d: case Intrinsic::x86_ssse3_phsub_d_128: - case Intrinsic::x86_avx2_phsub_d: { - handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/32); + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/32); + break; + + case Intrinsic::x86_avx2_phadd_d: + case Intrinsic::x86_avx2_phsub_d: + // TODO: Shards = 2 + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/32); break; - } // Packed Horizontal Add/Subtract and Saturate case Intrinsic::x86_ssse3_phadd_sw: case Intrinsic::x86_ssse3_phadd_sw_128: - case Intrinsic::x86_avx2_phadd_sw: case Intrinsic::x86_ssse3_phsub_sw: case Intrinsic::x86_ssse3_phsub_sw_128: - case Intrinsic::x86_avx2_phsub_sw: { - handlePairwiseShadowOrIntrinsic(I, /*ReinterpretElemWidth=*/16); + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/16); + break; + + case Intrinsic::x86_avx2_phadd_sw: + case Intrinsic::x86_avx2_phsub_sw: + // TODO: Shards = 2 + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1, + /*ReinterpretElemWidth=*/16); break; - } // Packed Single/Double Precision Floating-Point Horizontal Add case Intrinsic::x86_sse3_hadd_ps: case Intrinsic::x86_sse3_hadd_pd: - case Intrinsic::x86_avx_hadd_pd_256: - case Intrinsic::x86_avx_hadd_ps_256: case Intrinsic::x86_sse3_hsub_ps: case Intrinsic::x86_sse3_hsub_pd: + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1); + break; + + case Intrinsic::x86_avx_hadd_pd_256: + case Intrinsic::x86_avx_hadd_ps_256: case Intrinsic::x86_avx_hsub_pd_256: - case Intrinsic::x86_avx_hsub_ps_256: { - handlePairwiseShadowOrIntrinsic(I); + case Intrinsic::x86_avx_hsub_ps_256: + // TODO: Shards = 2 + handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1); break; - } case Intrinsic::x86_avx_maskstore_ps: case Intrinsic::x86_avx_maskstore_pd: @@ -6455,7 +6499,7 @@ struct MemorySanitizerVisitor : public InstVisitor { // Add Long Pairwise case Intrinsic::aarch64_neon_saddlp: case Intrinsic::aarch64_neon_uaddlp: { - handlePairwiseShadowOrIntrinsic(I); + 
handlePairwiseShadowOrIntrinsic(I, /*Shards=*/1); break; } diff --git a/llvm/lib/Transforms/Utils/ProfileVerify.cpp b/llvm/lib/Transforms/Utils/ProfileVerify.cpp index c578b4b839258..149c0879edcdd 100644 --- a/llvm/lib/Transforms/Utils/ProfileVerify.cpp +++ b/llvm/lib/Transforms/Utils/ProfileVerify.cpp @@ -102,9 +102,11 @@ bool ProfileInjector::inject() { for (auto &BB : F) { if (AnnotateSelect) { for (auto &I : BB) { - if (isa(I) && !I.getMetadata(LLVMContext::MD_prof)) - setBranchWeights(I, {SelectTrueWeight, SelectFalseWeight}, - /*IsExpected=*/false); + if (auto *SI = dyn_cast(&I)) + if (!SI->getCondition()->getType()->isVectorTy() && + !I.getMetadata(LLVMContext::MD_prof)) + setBranchWeights(I, {SelectTrueWeight, SelectFalseWeight}, + /*IsExpected=*/false); } } auto *Term = getTerminatorBenefitingFromMDProf(BB); @@ -185,9 +187,11 @@ PreservedAnalyses ProfileVerifierPass::run(Function &F, for (const auto &BB : F) { if (AnnotateSelect) { for (const auto &I : BB) - if (isa(I) && !I.getMetadata(LLVMContext::MD_prof)) - F.getContext().emitError( - "Profile verification failed: select annotation missing"); + if (auto *SI = dyn_cast(&I)) + if (!SI->getCondition()->getType()->isVectorTy() && + !I.getMetadata(LLVMContext::MD_prof)) + F.getContext().emitError( + "Profile verification failed: select annotation missing"); } if (const auto *Term = ProfileInjector::getTerminatorBenefitingFromMDProf(BB)) diff --git a/llvm/test/CodeGen/Hexagon/fast-math-libcalls.ll b/llvm/test/CodeGen/Hexagon/fast-math-libcalls.ll index 6bc60132d3e6a..831ab0a980368 100644 --- a/llvm/test/CodeGen/Hexagon/fast-math-libcalls.ll +++ b/llvm/test/CodeGen/Hexagon/fast-math-libcalls.ll @@ -9,15 +9,8 @@ define float @fast_sqrt_f32(float %x) { ; CHECK-LABEL: fast_sqrt_f32: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 ; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast2_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast2_sqrtf ; CHECK-NEXT: } %result = call nnan ninf nsz afn float @llvm.sqrt.f32(float %x) ret float %result @@ -27,15 +20,8 @@ define double @fast_sqrt_f64(double %x) { ; CHECK-LABEL: fast_sqrt_f64: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast2_sqrtdf2 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast2_sqrtdf2 ; CHECK-NEXT: } %result = call nnan ninf nsz afn double @llvm.sqrt.f64(double %x) ret double %result @@ -61,15 +47,8 @@ define double @fast_add_f64(double %x, double %y) { ; CHECK-LABEL: fast_add_f64: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast_adddf3 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast_adddf3 ; CHECK-NEXT: } %result = fadd nnan ninf nsz afn double %x, %y ret double %result @@ -95,15 +74,8 @@ define double @fast_sub_f64(double %x, double %y) { ; CHECK-LABEL: fast_sub_f64: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; 
CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 ; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast_subdf3 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast_subdf3 ; CHECK-NEXT: } %result = fsub nnan ninf nsz afn double %x, %y ret double %result @@ -129,15 +101,8 @@ define double @fast_mul_f64(double %x, double %y) { ; CHECK-LABEL: fast_mul_f64: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 ; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast_muldf3 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast_muldf3 ; CHECK-NEXT: } %result = fmul nnan ninf nsz afn double %x, %y ret double %result @@ -194,15 +159,8 @@ define double @fast_div_f64(double %x, double %y) { ; CHECK-LABEL: fast_div_f64: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_fast_divdf3 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_fast_divdf3 ; CHECK-NEXT: } %result = fdiv nnan ninf nsz afn double %x, %y ret double %result @@ -217,15 +175,8 @@ define float @sqrt_f32__afn(float %x) { ; CHECK-LABEL: sqrt_f32__afn: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_sqrtf ; CHECK-NEXT: } %result = call afn float @llvm.sqrt.f32(float %x) ret float %result @@ -235,15 +186,8 @@ define float @sqrt_f32__afn_ninf(float %x) { ; CHECK-LABEL: sqrt_f32__afn_ninf: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 ; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_sqrtf ; CHECK-NEXT: } %result = call afn ninf float @llvm.sqrt.f32(float %x) ret float %result @@ -253,15 +197,8 @@ define float @sqrt_f32__afn_nnan(float %x) { ; CHECK-LABEL: sqrt_f32__afn_nnan: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 ; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_sqrtf ; CHECK-NEXT: } %result = call afn nnan float @llvm.sqrt.f32(float %x) ret float %result @@ -271,15 +208,8 @@ define float @sqrt_f32__nnan(float %x) { ; CHECK-LABEL: sqrt_f32__nnan: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: 
jump __hexagon_sqrtf ; CHECK-NEXT: } %result = call nnan float @llvm.sqrt.f32(float %x) ret float %result @@ -289,15 +219,8 @@ define float @sqrt_f32_nnan_ninf_afn(float %x) { ; CHECK-LABEL: sqrt_f32_nnan_ninf_afn: ; CHECK: .cfi_startproc ; CHECK-NEXT: // %bb.0: -; CHECK-NEXT: .cfi_def_cfa r30, 8 -; CHECK-NEXT: .cfi_offset r31, -4 -; CHECK-NEXT: .cfi_offset r30, -8 -; CHECK-NEXT: { -; CHECK-NEXT: call __hexagon_sqrtf -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } ; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump __hexagon_sqrtf ; CHECK-NEXT: } %result = call nnan ninf afn float @llvm.sqrt.f32(float %x) ret float %result diff --git a/llvm/test/CodeGen/Hexagon/fminmax-v67.ll b/llvm/test/CodeGen/Hexagon/fminmax-v67.ll index ba4fcb5afdba3..8ce34210c38cf 100644 --- a/llvm/test/CodeGen/Hexagon/fminmax-v67.ll +++ b/llvm/test/CodeGen/Hexagon/fminmax-v67.ll @@ -2,7 +2,7 @@ ; CHECK-LABEL: t1 -; CHECK: call fmax +; CHECK: jump fmax define dso_local double @t1(double %a, double %b) local_unnamed_addr { entry: @@ -11,7 +11,7 @@ entry: } ; CHECK-LABEL: t2 -; CHECK: call fmin +; CHECK: jump fmin define dso_local double @t2(double %a, double %b) local_unnamed_addr { entry: @@ -20,7 +20,7 @@ entry: } ; CHECK-LABEL: t3 -; CHECK: call fmaxf +; CHECK: jump fmaxf define dso_local float @t3(float %a, float %b) local_unnamed_addr { entry: @@ -29,7 +29,7 @@ entry: } ; CHECK-LABEL: t4 -; CHECK: call fminf +; CHECK: jump fminf define dso_local float @t4(float %a, float %b) local_unnamed_addr { entry: diff --git a/llvm/test/CodeGen/Hexagon/fminmax.ll b/llvm/test/CodeGen/Hexagon/fminmax.ll index 2aae79e6b9bf3..e134168aefdfd 100644 --- a/llvm/test/CodeGen/Hexagon/fminmax.ll +++ b/llvm/test/CodeGen/Hexagon/fminmax.ll @@ -4,7 +4,7 @@ target datalayout = "e-m:e-p:32:32:32-a:0-n16:32-i64:64:64-i32:32:32-i16:16:16-i target triple = "hexagon" ; CHECK-LABEL: cfminf -; CHECK: call fminf +; CHECK: jump fminf define float @cfminf(float %x, float %y) #0 { entry: %call = tail call float @fminf(float %x, float %y) #1 @@ -12,7 +12,7 @@ entry: } ; CHECK-LABEL: cfmaxf -; CHECK: call fmaxf +; CHECK: jump fmaxf define float @cfmaxf(float %x, float %y) #0 { entry: %call = tail call float @fmaxf(float %x, float %y) #1 @@ -20,7 +20,7 @@ entry: } ; CHECK-LABEL: minnum -; CHECK: call fminf +; CHECK: jump fminf define float @minnum(float %x, float %y) #0 { entry: %call = tail call float @llvm.minnum.f32(float %x, float %y) #1 @@ -28,7 +28,7 @@ entry: } ; CHECK-LABEL: maxnum -; CHECK: call fmaxf +; CHECK: jump fmaxf define float @maxnum(float %x, float %y) #0 { entry: %call = tail call float @llvm.maxnum.f32(float %x, float %y) #1 diff --git a/llvm/test/CodeGen/Hexagon/fp16.ll b/llvm/test/CodeGen/Hexagon/fp16.ll index 2f933c92e42b8..40211f2a1a656 100644 --- a/llvm/test/CodeGen/Hexagon/fp16.ll +++ b/llvm/test/CodeGen/Hexagon/fp16.ll @@ -13,7 +13,7 @@ ; Validate that we generate correct lib calls to convert fp16 ;CHECK-LABEL: @test1 -;CHECK: call __extendhfsf2 +;CHECK: jump __extendhfsf2 ;CHECK: r0 = memuh define dso_local float @test1(ptr nocapture readonly %a) local_unnamed_addr #0 { entry: diff --git a/llvm/test/CodeGen/Hexagon/inline-division-space.ll b/llvm/test/CodeGen/Hexagon/inline-division-space.ll index c1937600d47bf..711a00bb9de5b 100644 --- a/llvm/test/CodeGen/Hexagon/inline-division-space.ll +++ b/llvm/test/CodeGen/Hexagon/inline-division-space.ll @@ -14,7 +14,7 @@ entry: ; Function Attrs: optsize define dso_local float @testFloat(float %a, float %b) local_unnamed_addr #0 { entry: -;CHECK: 
call __hexagon_divsf3 +;CHECK: jump __hexagon_divsf3 %div = fdiv float %a, %b ret float %div } @@ -22,7 +22,7 @@ entry: ; Function Attrs: optsize define dso_local double @testDouble(double %a, double %b) local_unnamed_addr #0 { entry: -;CHECK: call __hexagon_divdf3 +;CHECK: jump __hexagon_divdf3 %div = fdiv double %a, %b ret double %div } diff --git a/llvm/test/CodeGen/Hexagon/inline-division.ll b/llvm/test/CodeGen/Hexagon/inline-division.ll index 5eb97a002b0f4..b1b5fde53b3c6 100644 --- a/llvm/test/CodeGen/Hexagon/inline-division.ll +++ b/llvm/test/CodeGen/Hexagon/inline-division.ll @@ -23,7 +23,7 @@ entry: define dso_local double @testDouble(double %a, double %b) local_unnamed_addr { entry: -;CHECK: call __hexagon_divdf3 +;CHECK: jump __hexagon_divdf3 %div = fdiv double %a, %b ret double %div } diff --git a/llvm/test/CodeGen/Hexagon/libcall_tail.ll b/llvm/test/CodeGen/Hexagon/libcall_tail.ll new file mode 100644 index 0000000000000..2ea95abe8055a --- /dev/null +++ b/llvm/test/CodeGen/Hexagon/libcall_tail.ll @@ -0,0 +1,88 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; Test that libcalls used only by return are tail called. +; This tests non-float libcalls +; RUN: llc -march=hexagon -verify-machineinstrs < %s | FileCheck %s + +define i32 @udiv(i32 %a, i32 %b) nounwind { +; CHECK-LABEL: udiv: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: jump __hexagon_udivsi3 +; CHECK-NEXT: } + %1 = udiv i32 %a, %b + ret i32 %1 +} + +define i32 @udivconstby(i32 %a) nounwind { +; CHECK-LABEL: udivconstby: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: r1:0 = combine(r0,#10) +; CHECK-NEXT: jump __hexagon_udivsi3 +; CHECK-NEXT: } + %1 = udiv i32 10, %a + ret i32 %1 +} + +define i32 @sdiv(i32 %a, i32 %b) nounwind { +; CHECK-LABEL: sdiv: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: jump __hexagon_divsi3 +; CHECK-NEXT: } + %1 = sdiv i32 %a, %b + ret i32 %1 +} + +define i32 @sdivconstby(i32 %a) nounwind { +; CHECK-LABEL: sdivconstby: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: r1:0 = combine(r0,#10) +; CHECK-NEXT: jump __hexagon_divsi3 +; CHECK-NEXT: } + %1 = sdiv i32 10, %a + ret i32 %1 +} + +define i32 @urem(i32 %a, i32 %b) nounwind { +; CHECK-LABEL: urem: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: jump __hexagon_umodsi3 +; CHECK-NEXT: } + %1 = urem i32 %a, %b + ret i32 %1 +} + +define i32 @uremconstby(i32 %a) nounwind { +; CHECK-LABEL: uremconstby: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: r1:0 = combine(r0,#10) +; CHECK-NEXT: jump __hexagon_umodsi3 +; CHECK-NEXT: } + %1 = urem i32 10, %a + ret i32 %1 +} + +define i32 @srem(i32 %a, i32 %b) nounwind { +; CHECK-LABEL: srem: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: jump __hexagon_modsi3 +; CHECK-NEXT: } + %1 = srem i32 %a, %b + ret i32 %1 +} + +define i32 @sremconstby(i32 %a) nounwind { +; CHECK-LABEL: sremconstby: +; CHECK: // %bb.0: +; CHECK-NEXT: { +; CHECK-NEXT: r1:0 = combine(r0,#10) +; CHECK-NEXT: jump __hexagon_modsi3 +; CHECK-NEXT: } + %1 = srem i32 10, %a + ret i32 %1 +} diff --git a/llvm/test/CodeGen/Hexagon/llvm.exp10.ll b/llvm/test/CodeGen/Hexagon/llvm.exp10.ll index b5fcc4151225a..cd94d328f1fee 100644 --- a/llvm/test/CodeGen/Hexagon/llvm.exp10.ll +++ b/llvm/test/CodeGen/Hexagon/llvm.exp10.ll @@ -66,11 +66,7 @@ define float @exp10_f32(float %x) #0 { ; CHECK-LABEL: exp10_f32: ; CHECK: // %bb.0: ; CHECK-NEXT: { -; CHECK-NEXT: call exp10f -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = 
dealloc_return(r30):raw +; CHECK-NEXT: jump exp10f ; CHECK-NEXT: } %r = call float @llvm.exp10.f32(float %x) ret float %r @@ -103,11 +99,7 @@ define double @exp10_f64(double %x) #0 { ; CHECK-LABEL: exp10_f64: ; CHECK: // %bb.0: ; CHECK-NEXT: { -; CHECK-NEXT: call exp10 -; CHECK-NEXT: allocframe(r29,#0):raw -; CHECK-NEXT: } -; CHECK-NEXT: { -; CHECK-NEXT: r31:30 = dealloc_return(r30):raw +; CHECK-NEXT: jump exp10 ; CHECK-NEXT: } %r = call double @llvm.exp10.f64(double %x) ret double %r diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-intrinsics.ll index 877fe5fe4b393..d32a1d0034c84 100644 --- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-intrinsics.ll +++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-intrinsics.ll @@ -6,7 +6,6 @@ ; Strictly handled: ; - llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B) ; - llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A) -; - llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B) ; ; Heuristically handled: (none) @@ -241,25 +240,20 @@ define <16 x float> @test_mm512_dpbf16ps_512(<16 x float> %E, <32 x bfloat> %A, ; CHECK-LABEL: define <16 x float> @test_mm512_dpbf16ps_512( ; CHECK-SAME: <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 128), align 8 +; CHECK-NEXT: [[TMP11:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <16 x i32> [[TMP0]] to i512 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP3]], 0 -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP4]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]] -; CHECK: [[BB6]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]] -; CHECK-NEXT: unreachable -; CHECK: [[BB7]]: +; CHECK-NEXT: [[TMP3:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = or <32 x i1> [[TMP3]], [[TMP4]] +; CHECK-NEXT: [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16> +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <32 x i16> [[TMP6]] to <16 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = icmp ne <16 x i32> [[TMP7]], zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i1> [[TMP12]] to <16 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = or <16 x i32> [[TMP9]], [[TMP11]] ; CHECK-NEXT: [[TMP8:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]]) -; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8 +; CHECK-NEXT: store <16 x i32> [[TMP10]], ptr @__msan_retval_tls, align 8 ; CHECK-NEXT: ret <16 x float> [[TMP8]] ; entry: @@ -271,31 +265,26 @@ define <16 x float> 
@test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <32 x bfloat ; CHECK-LABEL: define <16 x float> @test_mm512_maskz_dpbf16ps_512( ; CHECK-SAME: <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]], i16 zeroext [[U:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 128), align 8 +; CHECK-NEXT: [[TMP18:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr getelementptr (i8, ptr @__msan_param_tls, i64 192), align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP0]] to i512 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP1]] to i512 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <32 x i16> [[TMP2]] to i512 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <32 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <32 x i1> [[TMP6]] to <32 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <32 x i16> [[TMP7]] to <16 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <16 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <16 x i1> [[TMP19]] to <16 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <16 x i32> [[TMP20]], [[TMP18]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i16 [[TMP3]] to <16 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i16 [[U]] to <16 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> zeroinitializer, <16 x i32> zeroinitializer +; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> [[TMP21]], <16 x i32> zeroinitializer ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <16 x float> [[TMP9]] to <16 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = xor <16 x i32> [[TMP13]], zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = or <16 x i32> [[TMP14]], zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = or <16 x i32> [[TMP14]], [[TMP21]] ; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], zeroinitializer ; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <16 x i1> [[TMP10]], <16 x i32> [[TMP16]], <16 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP17:%.*]] = select <16 x i1> [[TMP11]], <16 x float> [[TMP9]], <16 x float> zeroinitializer @@ -312,32 +301,27 @@ define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> % ; CHECK-LABEL: define <16 x float> @test_mm512_mask_dpbf16ps_512( ; CHECK-SAME: i16 zeroext [[U:%.*]], <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr 
getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 72), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 136), align 8 +; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP0]] to i512 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP1]] to i512 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <32 x i16> [[TMP2]] to i512 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <32 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <32 x i1> [[TMP6]] to <32 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <32 x i16> [[TMP7]] to <16 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <16 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <16 x i1> [[TMP19]] to <16 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <16 x i32> [[TMP20]], [[TMP0]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i16 [[TMP3]] to <16 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i16 [[U]] to <16 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> zeroinitializer, <16 x i32> [[TMP0]] +; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> [[TMP21]], <16 x i32> [[TMP0]] ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <16 x float> [[TMP9]] to <16 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = bitcast <16 x float> [[E]] to <16 x i32> ; CHECK-NEXT: [[TMP15:%.*]] = xor <16 x i32> [[TMP13]], [[TMP14]] -; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], [[TMP21]] ; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i32> [[TMP16]], [[TMP0]] ; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <16 x i1> [[TMP10]], <16 x i32> [[TMP17]], <16 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP18:%.*]] = select <16 x i1> [[TMP11]], <16 x float> [[TMP9]], <16 x float> [[E]] diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-vl-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-vl-intrinsics.ll index 904614e961d6c..a46d1ac9e2ab8 100644 --- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-vl-intrinsics.ll +++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-vl-intrinsics.ll @@ -7,8 +7,6 @@ ; - llvm.x86.avx512bf16.cvtne2ps2bf16.128(<4 x float> %A, <4 x float> %B) ; - llvm.x86.avx512bf16.cvtne2ps2bf16.256(<8 x float> %A, <8 x float> %B) ; - llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> %A) -; - llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> %E, <8 x bfloat> %A, <8 x bfloat> %B) -; - 
llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> %E, <16 x bfloat> %A, <16 x bfloat> %B) ; - llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> %A, <8 x bfloat> %6, <4 x i1> %4) ; ; Heuristically handled: (none) @@ -492,25 +490,20 @@ define <8 x float> @test_mm256_dpbf16ps_256(<8 x float> %E, <16 x bfloat> %A, <1 ; CHECK-LABEL: define <8 x float> @test_mm256_dpbf16ps_256( ; CHECK-SAME: <8 x float> [[E:%.*]], <16 x bfloat> [[A:%.*]], <16 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 32), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8 +; CHECK-NEXT: [[TMP11:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i32> [[TMP0]] to i256 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP3]], 0 -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i16> [[TMP1]] to i256 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP4]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <16 x i16> [[TMP2]] to i256 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i256 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]] -; CHECK: [[BB6]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB7]]: +; CHECK-NEXT: [[TMP3:%.*]] = icmp ne <16 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <16 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = or <16 x i1> [[TMP3]], [[TMP4]] +; CHECK-NEXT: [[TMP6:%.*]] = sext <16 x i1> [[TMP5]] to <16 x i16> +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <16 x i16> [[TMP6]] to <8 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = icmp ne <8 x i32> [[TMP7]], zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = sext <8 x i1> [[TMP12]] to <8 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = or <8 x i32> [[TMP9]], [[TMP11]] ; CHECK-NEXT: [[TMP8:%.*]] = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> [[E]], <16 x bfloat> [[A]], <16 x bfloat> [[B]]) -; CHECK-NEXT: store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8 +; CHECK-NEXT: store <8 x i32> [[TMP10]], ptr @__msan_retval_tls, align 8 ; CHECK-NEXT: ret <8 x float> [[TMP8]] ; entry: @@ -522,31 +515,26 @@ define <8 x float> @test_mm256_maskz_dpbf16ps_256(<8 x float> %E, <16 x bfloat> ; CHECK-LABEL: define <8 x float> @test_mm256_maskz_dpbf16ps_256( ; CHECK-SAME: <8 x float> [[E:%.*]], <16 x bfloat> [[A:%.*]], <16 x bfloat> [[B:%.*]], i8 zeroext [[U:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 32), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8 +; CHECK-NEXT: [[TMP18:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr getelementptr (i8, ptr @__msan_param_tls, i64 96), align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP0]] to i256 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = 
bitcast <16 x i16> [[TMP1]] to i256 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <16 x i16> [[TMP2]] to i256 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i256 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <16 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <16 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <16 x i1> [[TMP6]] to <16 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <16 x i16> [[TMP7]] to <8 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <8 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <8 x i1> [[TMP19]] to <8 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <8 x i32> [[TMP20]], [[TMP18]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> [[E]], <16 x bfloat> [[A]], <16 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i8 [[TMP3]] to <8 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i8 [[U]] to <8 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <8 x i1> [[TMP11]], <8 x i32> zeroinitializer, <8 x i32> zeroinitializer +; CHECK-NEXT: [[TMP12:%.*]] = select <8 x i1> [[TMP11]], <8 x i32> [[TMP21]], <8 x i32> zeroinitializer ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <8 x float> [[TMP9]] to <8 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = xor <8 x i32> [[TMP13]], zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = or <8 x i32> [[TMP14]], zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = or <8 x i32> [[TMP14]], [[TMP21]] ; CHECK-NEXT: [[TMP16:%.*]] = or <8 x i32> [[TMP15]], zeroinitializer ; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <8 x i1> [[TMP10]], <8 x i32> [[TMP16]], <8 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP17:%.*]] = select <8 x i1> [[TMP11]], <8 x float> [[TMP9]], <8 x float> zeroinitializer @@ -563,32 +551,27 @@ define <8 x float> @test_mm256_mask_dpbf16ps_256(i8 zeroext %U, <8 x float> %E, ; CHECK-LABEL: define <8 x float> @test_mm256_mask_dpbf16ps_256( ; CHECK-SAME: i8 zeroext [[U:%.*]], <8 x float> [[E:%.*]], <16 x bfloat> [[A:%.*]], <16 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 40), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 72), align 8 +; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP0]] to i256 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <16 x i16> [[TMP1]] to i256 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <16 x i16> [[TMP2]] to i256 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i256 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 
[[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <16 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <16 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <16 x i1> [[TMP6]] to <16 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <16 x i16> [[TMP7]] to <8 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <8 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <8 x i1> [[TMP19]] to <8 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <8 x i32> [[TMP20]], [[TMP0]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(<8 x float> [[E]], <16 x bfloat> [[A]], <16 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i8 [[TMP3]] to <8 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i8 [[U]] to <8 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <8 x i1> [[TMP11]], <8 x i32> zeroinitializer, <8 x i32> [[TMP0]] +; CHECK-NEXT: [[TMP12:%.*]] = select <8 x i1> [[TMP11]], <8 x i32> [[TMP21]], <8 x i32> [[TMP0]] ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <8 x float> [[TMP9]] to <8 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = bitcast <8 x float> [[E]] to <8 x i32> ; CHECK-NEXT: [[TMP15:%.*]] = xor <8 x i32> [[TMP13]], [[TMP14]] -; CHECK-NEXT: [[TMP16:%.*]] = or <8 x i32> [[TMP15]], zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = or <8 x i32> [[TMP15]], [[TMP21]] ; CHECK-NEXT: [[TMP17:%.*]] = or <8 x i32> [[TMP16]], [[TMP0]] ; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <8 x i1> [[TMP10]], <8 x i32> [[TMP17]], <8 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP18:%.*]] = select <8 x i1> [[TMP11]], <8 x float> [[TMP9]], <8 x float> [[E]] @@ -608,25 +591,20 @@ define <4 x float> @test_mm128_dpbf16ps_128(<4 x float> %E, <8 x bfloat> %A, <8 ; CHECK-LABEL: define <4 x float> @test_mm128_dpbf16ps_128( ; CHECK-SAME: <4 x float> [[E:%.*]], <8 x bfloat> [[A:%.*]], <8 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 16), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 32), align 8 +; CHECK-NEXT: [[TMP11:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i32> [[TMP0]] to i128 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0 -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i16> [[TMP1]] to i128 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x i16> [[TMP2]] to i128 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i128 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]] -; CHECK: [[BB6]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB7]]: +; CHECK-NEXT: [[TMP3:%.*]] = icmp ne <8 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <8 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = or <8 x i1> [[TMP3]], [[TMP4]] +; CHECK-NEXT: [[TMP6:%.*]] = sext <8 x i1> [[TMP5]] to <8 x i16> +; CHECK-NEXT: 
[[TMP7:%.*]] = bitcast <8 x i16> [[TMP6]] to <4 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = icmp ne <4 x i32> [[TMP7]], zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = sext <4 x i1> [[TMP12]] to <4 x i32> +; CHECK-NEXT: [[TMP10:%.*]] = or <4 x i32> [[TMP9]], [[TMP11]] ; CHECK-NEXT: [[TMP8:%.*]] = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> [[E]], <8 x bfloat> [[A]], <8 x bfloat> [[B]]) -; CHECK-NEXT: store <4 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8 +; CHECK-NEXT: store <4 x i32> [[TMP10]], ptr @__msan_retval_tls, align 8 ; CHECK-NEXT: ret <4 x float> [[TMP8]] ; entry: @@ -638,31 +616,26 @@ define <4 x float> @test_mm128_maskz_dpbf16ps_128(<4 x float> %E, <8 x bfloat> % ; CHECK-LABEL: define <4 x float> @test_mm128_maskz_dpbf16ps_128( ; CHECK-SAME: <4 x float> [[E:%.*]], <8 x bfloat> [[A:%.*]], <8 x bfloat> [[B:%.*]], i4 zeroext [[U:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 16), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 32), align 8 +; CHECK-NEXT: [[TMP18:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i4, ptr getelementptr (i8, ptr @__msan_param_tls, i64 48), align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP0]] to i128 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x i16> [[TMP1]] to i128 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <8 x i16> [[TMP2]] to i128 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i128 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <8 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <8 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <8 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <8 x i1> [[TMP6]] to <8 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <8 x i16> [[TMP7]] to <4 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <4 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <4 x i1> [[TMP19]] to <4 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <4 x i32> [[TMP20]], [[TMP18]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> [[E]], <8 x bfloat> [[A]], <8 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i4 [[TMP3]] to <4 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i4 [[U]] to <4 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP11]], <4 x i32> zeroinitializer, <4 x i32> zeroinitializer +; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP11]], <4 x i32> [[TMP21]], <4 x i32> zeroinitializer ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <4 x float> [[TMP9]] to <4 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = xor <4 x i32> [[TMP13]], zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = or <4 x i32> [[TMP14]], zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = or <4 x i32> [[TMP14]], [[TMP21]] ; CHECK-NEXT: [[TMP16:%.*]] = or <4 x i32> [[TMP15]], zeroinitializer ; 
CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <4 x i1> [[TMP10]], <4 x i32> [[TMP16]], <4 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP17:%.*]] = select <4 x i1> [[TMP11]], <4 x float> [[TMP9]], <4 x float> zeroinitializer @@ -679,32 +652,27 @@ define <4 x float> @test_mm128_mask_dpbf16ps_128(i4 zeroext %U, <4 x float> %E, ; CHECK-LABEL: define <4 x float> @test_mm128_mask_dpbf16ps_128( ; CHECK-SAME: i4 zeroext [[U:%.*]], <4 x float> [[E:%.*]], <8 x bfloat> [[A:%.*]], <8 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 24), align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 40), align 8 +; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load i4, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP0]] to i128 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <8 x i16> [[TMP1]] to i128 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = bitcast <8 x i16> [[TMP2]] to i128 -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i128 [[TMP6]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]] -; CHECK: [[BB7]]: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR5]] -; CHECK-NEXT: unreachable -; CHECK: [[BB8]]: +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <8 x i16> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <8 x i16> [[TMP2]], zeroinitializer +; CHECK-NEXT: [[TMP6:%.*]] = or <8 x i1> [[TMP4]], [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = sext <8 x i1> [[TMP6]] to <8 x i16> +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <8 x i16> [[TMP7]] to <4 x i32> +; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <4 x i32> [[TMP8]], zeroinitializer +; CHECK-NEXT: [[TMP20:%.*]] = sext <4 x i1> [[TMP19]] to <4 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = or <4 x i32> [[TMP20]], [[TMP0]] ; CHECK-NEXT: [[TMP9:%.*]] = tail call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(<4 x float> [[E]], <8 x bfloat> [[A]], <8 x bfloat> [[B]]) ; CHECK-NEXT: [[TMP10:%.*]] = bitcast i4 [[TMP3]] to <4 x i1> ; CHECK-NEXT: [[TMP11:%.*]] = bitcast i4 [[U]] to <4 x i1> -; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP11]], <4 x i32> zeroinitializer, <4 x i32> [[TMP0]] +; CHECK-NEXT: [[TMP12:%.*]] = select <4 x i1> [[TMP11]], <4 x i32> [[TMP21]], <4 x i32> [[TMP0]] ; CHECK-NEXT: [[TMP13:%.*]] = bitcast <4 x float> [[TMP9]] to <4 x i32> ; CHECK-NEXT: [[TMP14:%.*]] = bitcast <4 x float> [[E]] to <4 x i32> ; CHECK-NEXT: [[TMP15:%.*]] = xor <4 x i32> [[TMP13]], [[TMP14]] -; CHECK-NEXT: [[TMP16:%.*]] = or <4 x i32> [[TMP15]], zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = or <4 x i32> [[TMP15]], [[TMP21]] ; CHECK-NEXT: [[TMP17:%.*]] = or <4 x i32> [[TMP16]], [[TMP0]] ; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <4 x i1> [[TMP10]], <4 x i32> [[TMP17]], <4 x i32> [[TMP12]] ; CHECK-NEXT: [[TMP18:%.*]] = select <4 x i1> [[TMP11]], <4 x float> [[TMP9]], <4 x float> [[E]] diff --git a/llvm/test/Transforms/PGOProfile/profcheck-select.ll b/llvm/test/Transforms/PGOProfile/profcheck-select.ll index 
74bcb3f52428b..e6b3ddd42fcb0 100644 --- a/llvm/test/Transforms/PGOProfile/profcheck-select.ll +++ b/llvm/test/Transforms/PGOProfile/profcheck-select.ll @@ -15,7 +15,11 @@ ; RUN: not opt -passes=prof-verify %t/verify-missing.ll 2>&1 | FileCheck %t/verify-missing.ll ; verify we can disable it. It's sufficient to see opt not failing. -; RUN: opt -passes=prof-verify -profcheck-annotate-select=0 %t/verify-missing.ll +; RUN: opt -passes=prof-verify -profcheck-annotate-select=0 --disable-output %t/verify-missing.ll + +; verify vector selects without profiles are OK. It's sufficient that opt doesn't fail. +; RUN: opt -passes=prof-verify --disable-output %t/verify-vec.ll + ;--- inject.ll declare void @foo(i32 %a); @@ -24,8 +28,16 @@ define void @bar(i1 %c) { call void @foo(i32 %v) ret void } + +define <2 x i32> @vec(<2 x i1> %c, <2 x i32> %v1, <2 x i32> %v2) { + %r = select <2 x i1> %c, <2 x i32> %v1, <2 x i32> %v2 + ret <2 x i32> %r +} + ; CHECK-LABEL: @bar ; CHECK: %v = select i1 %c, i32 1, i32 2, !prof !1 +; CHECK-LABEL: @vec +; CHECK-NOT: select {{.*}} !prof ; CHECK: !0 = !{!"function_entry_count", i64 1000} ; CHECK: !1 = !{!"branch_weights", i32 2, i32 3} @@ -64,4 +76,10 @@ define void @bar(i1 %c) !prof !0 { ret void } !0 = !{!"function_entry_count", i64 1000} -; CHECK: Profile verification failed: select annotation missing \ No newline at end of file +; CHECK: Profile verification failed: select annotation missing + +;--- verify-vec.ll +define <2 x i32> @vec(<2 x i1> %c, <2 x i32> %v1, <2 x i32> %v2) !prof !{!"function_entry_count", i32 10} { + %r = select <2 x i1> %c, <2 x i32> %v1, <2 x i32> %v2 + ret <2 x i32> %r +} diff --git a/llvm/utils/profcheck-xfail.txt b/llvm/utils/profcheck-xfail.txt index 10a7b62229b8d..a8cf2b1625f8f 100644 --- a/llvm/utils/profcheck-xfail.txt +++ b/llvm/utils/profcheck-xfail.txt @@ -24,7 +24,6 @@ DebugInfo/Generic/block-asan.ll DebugInfo/X86/asan_debug_info.ll LTO/X86/diagnostic-handler-remarks-with-hotness.ll Other/optimization-remarks-auto.ll -Other/X86/debugcounter-partiallyinlinelibcalls.ll Transforms/AtomicExpand/ARM/atomic-expansion-v7.ll Transforms/AtomicExpand/SPARC/partword.ll Transforms/Attributor/align.ll @@ -94,8 +93,6 @@ Transforms/CodeGenPrepare/NVPTX/bypass-slow-div-constant-numerator.ll Transforms/CodeGenPrepare/NVPTX/bypass-slow-div.ll Transforms/CodeGenPrepare/NVPTX/bypass-slow-div-not-exact.ll Transforms/CodeGenPrepare/NVPTX/bypass-slow-div-special-cases.ll -Transforms/CodeGenPrepare/X86/vec-shift-inseltpoison.ll -Transforms/CodeGenPrepare/X86/vec-shift.ll Transforms/Coroutines/coro-await-suspend-lower-invoke.ll Transforms/Coroutines/coro-await-suspend-lower.ll Transforms/Coroutines/coro-byval-param.ll @@ -182,8 +179,6 @@ Transforms/GlobalOpt/shrink-global-to-bool-check-debug.ll Transforms/GlobalOpt/shrink-global-to-bool-opaque-ptrs.ll Transforms/GVN/debugloc-load-select.ll Transforms/GVN/load-through-select-dbg.ll -Transforms/GVN/masked-load-store.ll -Transforms/GVN/masked-load-store-no-mem-dep.ll Transforms/GVN/opaque-ptr.ll Transforms/GVN/pr69301.ll Transforms/GVN/pre-invalid-prof-metadata.ll @@ -236,9 +231,6 @@ Transforms/IndVarSimplify/pr45835.ll Transforms/IndVarSimplify/preserving-debugloc-rem-div.ll Transforms/InstCombine/2004-09-20-BadLoadCombine.ll Transforms/InstCombine/2005-04-07-UDivSelectCrash.ll -Transforms/InstCombine/AArch64/sve-intrinsic-sel.ll -Transforms/InstCombine/AArch64/sve-intrinsic-simplify-binop.ll -Transforms/InstCombine/AArch64/sve-intrinsic-simplify-shift.ll Transforms/InstCombine/add-mask.ll 
Transforms/InstCombine/add-shl-mul-umax.ll Transforms/InstCombine/and2.ll @@ -274,7 +266,6 @@ Transforms/InstCombine/fmul-bool.ll Transforms/InstCombine/fmul.ll Transforms/InstCombine/fneg.ll Transforms/InstCombine/fold-ctpop-of-not.ll -Transforms/InstCombine/fold-ext-eq-c-with-op.ll Transforms/InstCombine/free-inversion.ll Transforms/InstCombine/icmp-and-lowbit-mask.ll Transforms/InstCombine/icmp.ll @@ -294,8 +285,6 @@ Transforms/InstCombine/loadstore-metadata.ll Transforms/InstCombine/logical-select-inseltpoison.ll Transforms/InstCombine/logical-select.ll Transforms/InstCombine/lshr.ll -Transforms/InstCombine/masked_intrinsics-inseltpoison.ll -Transforms/InstCombine/masked_intrinsics.ll Transforms/InstCombine/memchr-11.ll Transforms/InstCombine/memchr-2.ll Transforms/InstCombine/memchr-3.ll @@ -332,14 +321,11 @@ Transforms/InstCombine/select-and-or.ll Transforms/InstCombine/select-cmp-br.ll Transforms/InstCombine/select-cmp.ll Transforms/InstCombine/select-factorize.ll -Transforms/InstCombine/select_frexp.ll Transforms/InstCombine/select.ll Transforms/InstCombine/select-min-max.ll Transforms/InstCombine/select-of-symmetric-selects.ll Transforms/InstCombine/select-select.ll Transforms/InstCombine/shift.ll -Transforms/InstCombine/shuffle-select-narrow-inseltpoison.ll -Transforms/InstCombine/shuffle-select-narrow.ll Transforms/InstCombine/simplify-demanded-fpclass.ll Transforms/InstCombine/sink-not-into-another-hand-of-logical-and.ll Transforms/InstCombine/sink-not-into-another-hand-of-logical-or.ll @@ -355,11 +341,8 @@ Transforms/InstCombine/sub-xor-cmp.ll Transforms/InstCombine/truncating-saturate.ll Transforms/InstCombine/unordered-fcmp-select.ll Transforms/InstCombine/urem-via-cmp-select.ll -Transforms/InstCombine/vec_sext.ll -Transforms/InstCombine/vector-urem.ll Transforms/InstCombine/wcslen-1.ll Transforms/InstCombine/wcslen-3.ll -Transforms/InstCombine/X86/blend_x86.ll Transforms/InstCombine/X86/x86-avx512-inseltpoison.ll Transforms/InstCombine/X86/x86-avx512.ll Transforms/InstCombine/xor-and-or.ll @@ -607,28 +590,18 @@ Transforms/OpenMP/spmdization_indirect.ll Transforms/OpenMP/spmdization.ll Transforms/OpenMP/spmdization_no_guarding_two_reaching_kernels.ll Transforms/OpenMP/spmdization_remarks.ll -Transforms/PartiallyInlineLibCalls/X86/good-prototype.ll Transforms/PGOProfile/comdat.ll Transforms/PGOProfile/memop_profile_funclet_wasm.ll Transforms/PGOProfile/X86/macho.ll Transforms/PhaseOrdering/AArch64/constraint-elimination-placement.ll Transforms/PhaseOrdering/AArch64/globals-aa-required-for-vectorization.ll -Transforms/PhaseOrdering/AArch64/hoisting-sinking-required-for-vectorization.ll -Transforms/PhaseOrdering/AArch64/predicated-reduction.ll -Transforms/PhaseOrdering/AArch64/quant_4x4.ll -Transforms/PhaseOrdering/ARM/arm_mean_q7.ll -Transforms/PhaseOrdering/vector-select.ll -Transforms/PhaseOrdering/X86/blendv-select.ll Transforms/PhaseOrdering/X86/merge-functions2.ll Transforms/PhaseOrdering/X86/merge-functions3.ll Transforms/PhaseOrdering/X86/merge-functions.ll Transforms/PhaseOrdering/X86/pr52078.ll -Transforms/PhaseOrdering/X86/pr67803.ll Transforms/PhaseOrdering/X86/preserve-access-group.ll -Transforms/PhaseOrdering/X86/vector-reductions.ll Transforms/PreISelIntrinsicLowering/AArch64/expand-exp.ll Transforms/PreISelIntrinsicLowering/AArch64/expand-log.ll -Transforms/PreISelIntrinsicLowering/expand-vp.ll Transforms/PreISelIntrinsicLowering/PowerPC/memset-pattern.ll Transforms/PreISelIntrinsicLowering/RISCV/memset-pattern.ll 
Transforms/PreISelIntrinsicLowering/X86/memcpy-inline-non-constant-len.ll @@ -636,7 +609,6 @@ Transforms/PreISelIntrinsicLowering/X86/memset-inline-non-constant-len.ll Transforms/PreISelIntrinsicLowering/X86/memset-pattern.ll Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll Transforms/SampleProfile/remarks-hotness.ll -Transforms/SandboxVectorizer/special_opcodes.ll Transforms/ScalarizeMaskedMemIntrin/AArch64/expand-masked-load.ll Transforms/ScalarizeMaskedMemIntrin/AArch64/expand-masked-store.ll Transforms/ScalarizeMaskedMemIntrin/AArch64/streaming-compatible-expand-masked-gather-scatter.ll @@ -675,5 +647,3 @@ Transforms/UnifyLoopExits/switch.ll Transforms/UnifyLoopExits/undef-phis.ll Transforms/Util/libcalls-opt-remarks.ll Transforms/Util/lowerswitch.ll -Transforms/VectorCombine/AArch64/shuffletoidentity.ll -Transforms/VectorCombine/X86/shuffle-of-selects.ll diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 0965979b7c39d..19741f10ce8cc 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -596,51 +596,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f //===---------------------------------------------------------------------===// // WMMA intrinsics -class ROCDL_Wmma_IntrOp overloadedOperands, - list traits = []> : - ROCDL_IntrOp, - Arguments<(ins Variadic:$args)> { - let assemblyFormat = - "$args attr-dict `:` functional-type($args, $res)"; +class ROCDL_WMMA_IntrOp : ROCDL_IntrOp, + Arguments<(ins + LLVM_ScalarOrVectorOf:$A, + LLVM_ScalarOrVectorOf:$B, + LLVM_ScalarOrVectorOf:$C)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_Opsel_IntrOp : ROCDL_IntrOp, + Arguments<(ins + LLVM_ScalarOrVectorOf:$A, + LLVM_ScalarOrVectorOf:$B, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$opsel)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_IU_IntrOp : ROCDL_IntrOp, + Arguments<(ins + DefaultValuedAttr:$signA, + LLVM_ScalarOrVectorOf:$A, + DefaultValuedAttr:$signB, + LLVM_ScalarOrVectorOf:$B, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$clamp)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAll_Reuse_IntrOp : ROCDL_IntrOp, + Arguments<(ins + DefaultValuedAttr:$signA, + LLVM_ScalarOrVectorOf:$A, + DefaultValuedAttr:$signB, + LLVM_ScalarOrVectorOf:$B, + DefaultValuedAttr:$modC, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$reuseA, + DefaultValuedAttr:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsC_IntrOp : ROCDL_IntrOp, + Arguments<(ins + LLVM_ScalarOrVectorOf:$A, + LLVM_ScalarOrVectorOf:$B, + DefaultValuedAttr:$modC, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$reuseA, + DefaultValuedAttr:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAll_Diff_IntrOp : ROCDL_IntrOp, + Arguments<(ins + DefaultValuedAttr:$signA, + LLVM_ScalarOrVectorOf:$A, + DefaultValuedAttr:$signB, + 
LLVM_ScalarOrVectorOf:$B, + DefaultValuedAttr:$modC, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$reuseA, + DefaultValuedAttr:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; +} + +class ROCDL_WMMA_ModsAB_IntrOp : ROCDL_IntrOp, + Arguments<(ins + DefaultValuedAttr:$signA, + LLVM_ScalarOrVectorOf:$A, + DefaultValuedAttr:$signB, + LLVM_ScalarOrVectorOf:$B, + LLVM_ScalarOrVectorOf:$C, + DefaultValuedAttr:$reuseA, + DefaultValuedAttr:$reuseB)> { + let results = (outs LLVM_ScalarOrVectorOf:$res); + let assemblyFormat = [{ + $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res) + }]; } // Available from gfx11 -def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>; -def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>; -def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>; -def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>; -def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>; -def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>; +def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>; +def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>; +def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>; +def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>; +def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>; // Available from gfx12 -def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>; -def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>; -def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; -def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>; -def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>; +def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>; // Available from gfx1250 -def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>; -def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>; -def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>; -def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>; -def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>; -def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>; -def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x64_fp8_bf8 : 
ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>; -def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>; -def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>; -def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>; -def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>; -def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>; +def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>; +def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>; +def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>; +def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>; +def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>; +def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>; +def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>; +def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>; +def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>; +def 
ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>; +def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>; +def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>; //===---------------------------------------------------------------------===// // LDS transpose intrinsics (available in GFX950) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3a307a0756d93..a5831559558ac 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter, return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value); } -static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, - bool value) { - Type llvmI1 = rewriter.getI1Type(); - return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); -} - /// Returns the linear index used to access an element in the memref. static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, @@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, /// intrinsics having been defined before the AMD backend supported bfloat. We /// similarly need to pack 8-bit float types into integers as if they were i8 /// (which they are for the backend's purposes). 
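/// For example (illustrative, based on the wmma test updates in this patch): on gfx11 a /// vector<16xbf16> input reaches the intrinsic as vector<16xi16>, and on gfx12 a /// vector<4xf8E5M2> input is packed into a single i32. 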
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, - Location loc, - const TypeConverter *typeConverter, - bool isUnsigned, Value llvmInput, - Value mlirInput, - SmallVector<Value, 4> &operands) { +static void wmmaPushInputOperand( + ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast<VectorType>(inputType); if (!vectorType) { @@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } Type elemType = vectorType.getElementType(); - - if (elemType.isBF16()) - llvmInput = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -719,8 +709,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, } else if (elemType.isSignedInteger()) { localIsUnsigned = false; } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); + attrs.push_back( + NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned))); } int64_t numBits = @@ -751,18 +741,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, - bool clamp, SmallVector<Value, 4> &operands) { + bool clamp, SmallVectorImpl<Value> &operands, + SmallVectorImpl<NamedAttribute> &attrs) { Type inputType = output.getType(); auto vectorType = dyn_cast<VectorType>(inputType); Type elemType = vectorType.getElementType(); - if (elemType.isBF16()) - output = LLVM::BitcastOp::create( - rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { - operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); + attrs.push_back( + NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset))); } else if (elemType.isInteger(32)) { - operands.push_back(createI1Constant(rewriter, loc, clamp)); + attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp))); } } @@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); - // The WMMA operations represent vectors of bf16s as vectors of i16s, so we - // need to bitcast bfloats to i16 and then bitcast them back. + bool isGFX1250 = chipset >= Chipset(12, 5, 0); + + // The WMMA operations represent vectors of bf16s as vectors of i16s + // (except on gfx1250), so we need to bitcast bfloats to i16 and then + // bitcast them back. 
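+ // For example (illustrative; cf. the wmma-gfx11.mlir test below), on gfx11 + // amdgpu.wmma 16x16x16 %a * %b + %c : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + // lowers to rocdl.wmma.f32.16x16x16.bf16 on vector<16xi16> operands, while the + // gfx1250 intrinsics take bf16 vectors directly. 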
+ auto aType = cast<VectorType>(adaptor.getSourceA().getType()); + auto bType = cast<VectorType>(adaptor.getSourceB().getType()); + auto destCType = cast<VectorType>(adaptor.getDestC().getType()); + bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250; + bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250; + bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250; + bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250; VectorType rawOutType = outType; - if (outType.getElementType().isBF16()) + if (castOutToI16) rawOutType = outType.clone(rewriter.getI16Type()); + Value a = adaptor.getSourceA(); + if (castAToI16) + a = LLVM::BitcastOp::create(rewriter, loc, + aType.clone(rewriter.getI16Type()), a); + Value b = adaptor.getSourceB(); + if (castBToI16) + b = LLVM::BitcastOp::create(rewriter, loc, + bType.clone(rewriter.getI16Type()), b); + Value destC = adaptor.getDestC(); + if (castDestCToI16) + destC = LLVM::BitcastOp::create( + rewriter, loc, destCType.clone(rewriter.getI16Type()), destC); std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -1325,18 +1336,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0) return op.emitOpError("subwordOffset not supported on gfx12+"); - OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(rawOutType); - SmallVector<Value, 4> operands; - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), op.getSourceA(), operands); - wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), op.getSourceB(), operands); - wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), - op.getSubwordOffset(), op.getClamp(), operands); + SmallVector<NamedAttribute, 4> attrs; + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a, + op.getSourceA(), operands, attrs, "signA"); + wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b, + op.getSourceB(), operands, attrs, "signB"); + wmmaPushOutputOperand(rewriter, loc, typeConverter, destC, + op.getSubwordOffset(), op.getClamp(), operands, + attrs); + + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes(rawOutType); loweredOp.addOperands(operands); + loweredOp.addAttributes(attrs); Operation *lowered = rewriter.create(loweredOp); Operation *maybeCastBack = lowered; diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir index 9fcc1473d4a18..4e6aa17522374 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir @@ -6,30 +6,30 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>, %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32> // 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 9fcc1473d4a18..4e6aa17522374 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -6,30 +6,30 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 {subwordOffset = 0 : i32} : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 {subwordOffset = 0 : i32} : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} {opsel = true} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
-  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} {opsel = true} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
-  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
   return
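As the updated gfx11 checks show, the subwordOffset attribute on amdgpu.wmma now maps to the opsel bool attribute on the ROCDL op whenever the accumulator element type is 16 bits wide; subwordOffset = 0 simply leaves opsel at its false default, which is why those patterns no longer match a trailing i1 operand. A sketch, with illustrative SSA names:

  amdgpu.wmma 16x16x16 %a * %a + %c {subwordOffset = 1 : i32} : vector<16xf16>, vector<16xf16>, vector<16xf16>
  // becomes
  rocdl.wmma.f16.16x16x16.f16 %a, %a, %c {opsel = true} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>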
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 57883473bbf06..978227b4d5791 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -20,15 +20,15 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
   amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>) -> vector<8xf16>
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
 
-  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>) -> vector<4xf16>
   amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
 
-  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
 
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>) -> vector<4xi16>
   amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
 
   // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
@@ -51,19 +51,19 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
   // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
   amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}{clamp = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<8xi32>) -> vector<8xi32>
   amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
 
-  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}{clamp = true, signA = true, signB = true} : (i32, i32, vector<4xi32>) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
 
   func.return
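Note that none of the gfx12 cases exercise subwordOffset: per the guard in the pattern above, a nonzero value is rejected with "subwordOffset not supported on gfx12+". A hypothetical input such as

  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32} : vector<8xf16>, vector<8xf16>, vector<8xf16>

would therefore fail to lower on gfx12 targets rather than produce an opsel attribute.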
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 5e77a3add3184..37259f6ed06eb 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -14,13 +14,13 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
   // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
   amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>)
   amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>)
   amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
 
   return
@@ -29,31 +29,31 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
                     %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
-  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, %arg3 {clamp = true, signA = true, signB = true}
   amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
 
   // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
   amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5 {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
 
   return
@@ -65,25 +65,25 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
   // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
 
   // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
   amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
 
-  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3 {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>)
   amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
 
   return
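The gfx1250 checks also capture the other half of the isGFX1250 special case in the lowering: bf16 operands and results pass through natively instead of round-tripping through i16 vectors. Roughly, contrasting the two targets (patterns taken from the tests above):

  // gfx11/gfx12: bitcast in, bitcast out
  %raw = rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
  %res = llvm.bitcast %raw : vector<8xi16> to vector<8xbf16>

  // gfx1250: bf16 end to end, no bitcasts
  rocdl.wmma.bf16.16x16x32.bf16 {{.*}} : (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>)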
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 2922665295cf3..dcf80ad4395de 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -888,140 +888,138 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
                       %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
                       %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>,
                       %arg13 : vector<64xf32>, %arg14 : vector<64xi32>, %arg15 : vector<64xf16>, %arg16 : vector<16xbf16>,
                       %arg17 : vector<32xbf16>) -> vector<8xf32> {
-  %zero = llvm.mlir.constant(false) : i1
-  %zero_i16 = llvm.mlir.constant(0 : i16) : i16
-  // ---- Wave32 -----
+  // ---- Wave32 -----
   // f16 -> f32
-  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
 
   // bf16 -> f32
-  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x float> %{{.*}})
   %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
 
   // f16 -> f16 (OPSEL = {0,1})
-  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
-  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <16 x half> %{{.*}} i1 false)
+  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<16xf16>) -> vector<16xf16>
 
   // bf16 -> bf16 (OPSEL = {0,1})
-  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
-  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <16 x i16> %{{.*}} i1 false)
+  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<16xi16>) -> vector<16xi16>
 
   // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r5 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg3 {signA = false, signB = false, clamp = false} : (vector<4xi32>, vector<4xi32>, vector<8xi32>) -> vector<8xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
-  %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <8 x i32> %{{.*}} i1 false)
+  %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %arg4, %arg4, %arg3 {signA = false, signB = false, clamp = false} : (vector<2xi32>, vector<2xi32>, vector<8xi32>) -> vector<8xi32>
 
   // f32 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}} i1 false, <16 x float> %{{.*}} i16 0, <4 x float> %{{.*}} i1 false, i1 false)
+  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>
 
   // f16 -> f32
-  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %arg1, %arg1, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf32>) -> vector<32xf32>
 
   // bf16 -> f32
-  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xf32>
 
   // f16 -> f16
-  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %{{.*}} i1 false, <16 x half> %{{.*}} i16 0, <32 x half> %{{.*}} i1 false, i1 false)
+  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %arg1, %arg1, %arg9 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf16>, vector<16xf16>, vector<32xf16>) -> vector<32xf16>
 
   // bf16 -> bf16
-  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xbf16>, i1, i1) -> vector<32xbf16>
+  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x bfloat> %{{.*}} i1 false, i1 false)
+  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %arg16, %arg16, %arg17 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xbf16>) -> vector<32xbf16>
 
   // bf16 -> bf16 / f32
-  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg16, %zero, %arg16, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xbf16>
+  // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 false, <16 x bfloat> %{{.*}} i1 false, <16 x bfloat> %{{.*}} i16 0, <32 x float> %{{.*}} i1 false, i1 false)
+  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %arg16, %arg16, %arg12 {signA = false, signB = false, modC = 0 : i16} : (vector<16xbf16>, vector<16xbf16>, vector<32xf32>) -> vector<32xbf16>
 
   // f8/bf8 -> f16/f32
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
-
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg13, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r17.gfx1250 =
rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x float> %{{.*}} i1 false, i1 false)
+  %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5, %arg5, %arg13 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf32>) -> vector<64xf32>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
+
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
-  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %zero_i16, %arg15, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf16>, i1, i1) -> vector<64xf16>
+  // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}} <4 x i32> %{{.*}} i16 0, <64 x half> %{{.*}} i1 false, i1 false)
+  %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5, %arg5, %arg15 {signA = false, signB = false, modC = 0 : i16} : (vector<4xi32>, vector<4xi32>, vector<64xf16>) -> vector<64xf16>
 
   // iu8 -> i32
-  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}})
-  %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg14, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <64 x i32> %{{.*}} i1 false, i1 false)
+  %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %arg5, %arg5, %arg14 {signA = false, signB = false} : (vector<4xi32>, vector<4xi32>, vector<64xi32>) -> vector<64xi32>
 
   // ---- Wave64 -----
   // f16 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <4 x float> %{{.*}})
   %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
 
   // bf16 -> f32
-  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <4 x float> %{{.*}})
   %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
 
   // f16 -> f16 (OPSEL = {0,1})
-  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
-  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v16f16(<16 x half> %{{.*}} <16 x half> %{{.*}} <8 x half> %{{.*}} i1 false)
+  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7 {opsel = false} : (vector<16xf16>, vector<16xf16>, vector<8xf16>) -> vector<8xf16>
 
   // bf16 -> bf16 (OPSEL = {0,1})
-  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
-  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16.v16i16(<16 x i16> %{{.*}} <16 x i16> %{{.*}} <8 x i16> %{{.*}} i1 false)
+  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8 {opsel = false} : (vector<16xi16>, vector<16xi16>, vector<8xi16>) -> vector<8xi16>
 
   // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
-  %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32.v4i32(i1 false, <4 x i32> %{{.*}} i1 false, <4 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+  %r12 = rocdl.wmma.i32.16x16x16.iu8 %arg5, %arg5, %arg5 {signA = false, signB = false, clamp = true} : (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
 
   // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
-  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
-  %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32.v2i32(i1 false, <2 x i32> %{{.*}} i1 false, <2 x i32> %{{.*}} <4 x i32> %{{.*}} i1 true)
+  %r13 = rocdl.wmma.i32.16x16x16.iu4 %arg4, %arg4, %arg5 {signA = false, signB = false, clamp = true} : (vector<2xi32>, vector<2xi32>, vector<4xi32>) -> vector<4xi32>
 
   llvm.return %r0 : vector<8xf32>
 }
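For the gfx1250 rocdl ops, the former i16 matrix-C modifier operand (the %zero_i16 constant deleted above) is likewise folded into an inherent modC attribute alongside signA and signB. One representative case, taken directly from the rewritten test:

  %r = rocdl.wmma.f32.16x16x4.f32 %arg10, %arg10, %arg11 {signA = false, signB = false, modC = 0 : i16} : (vector<16xf32>, vector<16xf32>, vector<4xf32>) -> vector<4xf32>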
diff --git a/mlir/test/python/dialects/rocdl.py b/mlir/test/python/dialects/rocdl.py
index a4a50afa966c7..c73a536e03820 100644
--- a/mlir/test/python/dialects/rocdl.py
+++ b/mlir/test/python/dialects/rocdl.py
@@ -29,13 +29,12 @@ def testSmoke():
         a_frag = arith.constant(v16f32, f32_array)
         b_frag = arith.constant(v16f32, f32_array)
         c_frag = arith.constant(v16f32, f32_array)
-        false = arith.constant(T.bool(), False)
-        c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, [a_frag, b_frag, c_frag, false])
-        # CHECK: %{{.*}} = rocdl.wmma.f16.16x16x16.f16
+        c_frag = rocdl.wmma_f16_16x16x16_f16(v16f32, a_frag, b_frag, c_frag, opsel=False)
+        # CHECK: %{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
         print(c_frag)
         assert isinstance(c_frag, OpView)
-        # CHECK: Value(%{{.*}} = rocdl.wmma.f16.16x16x16.f16
-        c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, [a_frag, b_frag, c_frag, false])
+        # CHECK: Value(%{{.*}} = "rocdl.wmma.f16.16x16x16.f16"
+        c_frag = rocdl.wmma_f16_16x16x16_f16_(v16f32, a_frag, b_frag, c_frag, opsel=False)
         print(c_frag)
         assert isinstance(c_frag, Value)