diff --git a/llvm/test/lit.cfg.py b/llvm/test/lit.cfg.py
index 771d9245368b1..8a1b001695edc 100644
--- a/llvm/test/lit.cfg.py
+++ b/llvm/test/lit.cfg.py
@@ -93,6 +93,13 @@ def get_asan_rtlib():
 
 config.substitutions.append(("%exeext", config.llvm_exe_ext))
 config.substitutions.append(("%llvm_src_root", config.llvm_src_root))
+# Add IR2Vec test vocabulary path substitution
+config.substitutions.append(
+    (
+        "%ir2vec_test_vocab_dir",
+        os.path.join(config.test_source_root, "Analysis", "IR2Vec", "Inputs"),
+    )
+)
 
 lli_args = []
 # The target triple used by default by lli is the process target triple (some
diff --git a/llvm/test/tools/llvm-ir2vec/embeddings.ll b/llvm/test/tools/llvm-ir2vec/embeddings.ll
new file mode 100644
index 0000000000000..993ea865170f9
--- /dev/null
+++ b/llvm/test/tools/llvm-ir2vec/embeddings.ll
@@ -0,0 +1,73 @@
+; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
+; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC
+; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
+; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT
+; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT
+
+define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; CHECK-DEFAULT: Function: abc
+; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
+; CHECK-DEFAULT-NEXT: Function: abc_repeat
+; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-LEVEL: Function: abc
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
+; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat
+; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-LEVEL-ABC: Function: abc
+; CHECK-FUNC-LEVEL-ABC-NEXT: [ 878.00 889.00 900.00 ]
+
+; CHECK-FUNC-DEF: Error: Function 'def' not found
+
+; CHECK-BB-LEVEL: Function: abc
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
+; CHECK-BB-LEVEL-NEXT: Function: abc_repeat
+; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
+
+; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00 889.00 900.00 ]
+
+; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00 98.00 99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00 98.00 99.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00 50.00 51.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00 131.00 132.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00 95.00 96.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00 41.00 42.00 ]
+; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00 2.00 3.00 ]
diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll
index fa5aaa895406f..d1ef5b388e258 100644
--- a/llvm/test/tools/llvm-ir2vec/triplets.ll
+++ b/llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -1,4 +1,4 @@
-; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
+; RUN: llvm-ir2vec --mode=triplets %s | FileCheck %s -check-prefix=TRIPLETS
 
 define i32 @simple_add(i32 %a, i32 %b) {
 entry:
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index 35e1c995fa4cc..eba8c2e5678b1 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -9,12 +9,18 @@
 /// \file
 /// This file implements the IR2Vec embedding generation tool.
 ///
-/// Currently supports triplet generation for vocabulary training.
-/// Future updates will support embedding generation using trained vocabulary.
+/// This tool provides two main functionalities:
 ///
-/// Usage: llvm-ir2vec input.bc -o triplets.txt
+/// 1. Triplet Generation Mode (--mode=triplets):
+///    Generates triplets (opcode, type, operands) for vocabulary training.
+///    Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
 ///
-/// TODO: Add embedding generation mode with vocabulary support
+/// 2. Embedding Generation Mode (--mode=embeddings):
+///    Generates IR2Vec embeddings using a trained vocabulary.
+///    Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
+///           --level=func input.bc -o embeddings.txt
+///    Levels: inst (instructions), bb (basic blocks), func (functions)
+///    (See IR2Vec.cpp for more embedding generation options)
 ///
 //===----------------------------------------------------------------------===//
 
@@ -24,6 +30,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Support/CommandLine.h"
@@ -33,11 +41,11 @@
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
-using namespace llvm;
-using namespace ir2vec;
-
 #define DEBUG_TYPE "ir2vec"
 
+namespace llvm {
+namespace ir2vec {
+
 static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options");
 
 static cl::opt<std::string> InputFilename(cl::Positional,
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
                                            cl::init("-"),
                                            cl::cat(IR2VecToolCategory));
 
+enum ToolMode {
+  TripletMode,  // Generate triplets for vocabulary training
+  EmbeddingMode // Generate embeddings using trained vocabulary
+};
+
+static cl::opt<ToolMode>
+    Mode("mode", cl::desc("Tool operation mode:"),
+         cl::values(clEnumValN(TripletMode, "triplets",
+                               "Generate triplets for vocabulary training"),
+                    clEnumValN(EmbeddingMode, "embeddings",
+                               "Generate embeddings using trained vocabulary")),
+         cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
+
+static cl::opt<std::string>
+    FunctionName("function", cl::desc("Process specific function only"),
+                 cl::value_desc("name"), cl::Optional, cl::init(""),
+                 cl::cat(IR2VecToolCategory));
+
+enum EmbeddingLevel {
+  InstructionLevel, // Generate instruction-level embeddings
+  BasicBlockLevel,  // Generate basic block-level embeddings
+  FunctionLevel     // Generate function-level embeddings
+};
+
+static cl::opt<EmbeddingLevel>
+    Level("level", cl::desc("Embedding generation level (for embedding mode):"),
+          cl::values(clEnumValN(InstructionLevel, "inst",
+                                "Generate instruction-level embeddings"),
+                     clEnumValN(BasicBlockLevel, "bb",
+                                "Generate basic block-level embeddings"),
+                     clEnumValN(FunctionLevel, "func",
+                                "Generate function-level embeddings")),
+          cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
+
 namespace {
 
-/// Helper class for collecting IR information and generating triplets
+/// Helper class for collecting IR triplets and generating embeddings
 class IR2VecTool {
 private:
   Module &M;
+  ModuleAnalysisManager MAM;
+  const Vocabulary *Vocab = nullptr;
 
 public:
   explicit IR2VecTool(Module &M) : M(M) {}
 
+  /// Register and run the IR2Vec vocabulary analysis (vocabulary path comes
+  /// from --ir2vec-vocab-path). \returns true if the vocabulary is valid.
+  bool initializeVocabulary() {
+    // PassInstrumentationAnalysis has to be registered as well; the analysis
+    // manager looks it up while computing any other analysis result.
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
+    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
+    return Vocab->isValid();
+  }
+
   /// Generate triplets for the entire module
   void generateTriplets(raw_ostream &OS) const {
     for (const Function &F : M)
@@ -81,6 +136,68 @@ class IR2VecTool {
     OS << LocalOutput;
   }
 
+  /// Generate embeddings for the entire module
+  void generateEmbeddings(raw_ostream &OS) const {
+    if (!Vocab || !Vocab->isValid()) {
+      OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
+      return;
+    }
+
+    for (const Function &F : M)
+      generateEmbeddings(F, OS);
+  }
+
+  /// Generate embeddings for \p F at the granularity selected by --level
+  void generateEmbeddings(const Function &F, raw_ostream &OS) const {
+    if (F.isDeclaration()) {
+      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
+      return;
+    }
+
+    // Create embedder for this function
+    assert(Vocab && Vocab->isValid() && "Vocabulary is not valid");
+    auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
+    if (!Emb) {
+      OS << "Error: Failed to create embedder for function " << F.getName()
+         << "\n";
+      return;
+    }
+
+    OS << "Function: " << F.getName() << "\n";
+
+    // Generate embeddings based on the specified level
+    switch (Level) {
+    case FunctionLevel: {
+      Emb->getFunctionVector().print(OS);
+      break;
+    }
+    case BasicBlockLevel: {
+      const auto &BBVecMap = Emb->getBBVecMap();
+      for (const BasicBlock &BB : F) {
+        auto It = BBVecMap.find(&BB);
+        if (It != BBVecMap.end()) {
+          OS << BB.getName() << ":";
+          It->second.print(OS);
+        }
+      }
+      break;
+    }
+    case InstructionLevel: {
+      const auto &InstMap = Emb->getInstVecMap();
+      for (const BasicBlock &BB : F) {
+        for (const Instruction &I : BB) {
+          auto It = InstMap.find(&I);
+          if (It != InstMap.end()) {
+            I.print(OS);
+            It->second.print(OS);
+          }
+        }
+      }
+      break;
+    }
+    }
+  }
+
 private:
   /// Process a single basic block for triplet generation
   void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,23 +222,71 @@ class IR2VecTool {
 
 Error processModule(Module &M, raw_ostream &OS) {
   IR2VecTool Tool(M);
-  Tool.generateTriplets(OS);
 
+  if (Mode == EmbeddingMode) {
+    // Initialize vocabulary for embedding generation
+    // Note: Requires --ir2vec-vocab-path option to be set
+    if (!Tool.initializeVocabulary())
+      return createStringError(
+          errc::invalid_argument,
+          "Failed to initialize IR2Vec vocabulary. "
+          "Make sure to specify --ir2vec-vocab-path for embedding mode.");
+
+    if (!FunctionName.empty()) {
+      // Process single function
+      if (const Function *F = M.getFunction(FunctionName))
+        Tool.generateEmbeddings(*F, OS);
+      else
+        return createStringError(errc::invalid_argument,
+                                 "Function '%s' not found",
+                                 FunctionName.c_str());
+    } else {
+      // Process all functions
+      Tool.generateEmbeddings(OS);
+    }
+  } else {
+    // Triplet generation mode - no vocabulary needed
+    if (!FunctionName.empty()) {
+      // Process single function
+      if (const Function *F = M.getFunction(FunctionName))
+        Tool.generateTriplets(*F, OS);
+      else
+        return createStringError(errc::invalid_argument,
+                                 "Function '%s' not found",
+                                 FunctionName.c_str());
+    } else {
+      // Process all functions
+      Tool.generateTriplets(OS);
+    }
+  }
   return Error::success();
 }
-
-} // anonymous namespace
+} // namespace
+} // namespace ir2vec
+} // namespace llvm
 
 int main(int argc, char **argv) {
+  using namespace llvm;
+  using namespace llvm::ir2vec;
+
   InitLLVM X(argc, argv);
   cl::HideUnrelatedOptions(IR2VecToolCategory);
   cl::ParseCommandLineOptions(
       argc, argv,
-      "IR2Vec - Triplet Generation Tool\n"
-      "Generates triplets for vocabulary training from LLVM IR.\n"
-      "Future updates will support embedding generation.\n\n"
+      "IR2Vec - Embedding Generation Tool\n"
+      "Generates embeddings for a given LLVM IR and "
+      "supports triplet generation for vocabulary "
+      "training and embedding generation.\n\n"
       "Usage:\n"
-      " llvm-ir2vec input.bc -o triplets.txt\n");
+      " Triplet mode: llvm-ir2vec --mode=triplets input.bc\n"
+      " Embedding mode: llvm-ir2vec --mode=embeddings "
+      "--ir2vec-vocab-path=vocab.json --level=func input.bc\n"
+      " Levels: --level=inst (instructions), --level=bb (basic blocks), "
+      "--level=func (functions)\n");
+
+  // Validate command line options
+  if (Mode == TripletMode && Level.getNumOccurrences() > 0)
+    errs() << "Warning: --level option is ignored in triplet mode\n";
 
   // Parse the input LLVM IR file
   SMDiagnostic Err;