|
| 1 | +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 2 | +# See https://llvm.org/LICENSE.txt for license information. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | +"""IR2Vec Triplet Generator |
| 5 | +
|
| 6 | +Generates IR2Vec triplets by applying random optimization levels to LLVM IR files |
| 7 | +and extracting triplets using llvm-ir2vec. Automatically generates preprocessed |
| 8 | +files: entity2id.txt, relation2id.txt, and train2id.txt. |
| 9 | +
|
| 10 | +Usage: |
| 11 | + python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir> |
| 12 | +""" |
| 13 | + |
| 14 | +import argparse |
| 15 | +import logging |
| 16 | +import os |
| 17 | +import random |
| 18 | +import subprocess |
| 19 | +import sys |
| 20 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 21 | +from pathlib import Path |
| 22 | +from typing import List, Set, Tuple |
| 23 | + |
# Configuration
# Optimization levels that may be selected; each one is passed to opt as
# "-<level>" when the pipeline is run.
OPT_LEVELS = ["O0", "O1", "O2", "O3", "Os", "Oz"]
# Default size of the worker thread pool; overridable with -j/--max-workers.
DEFAULT_MAX_WORKERS = 100

# Module-level logger; its level and format are configured in main().
logger = logging.getLogger(__name__)
| 29 | + |
| 30 | + |
class TripletResult:
    """Outcome of processing one LLVM IR file.

    Holds the set of triplet lines collected for the file and the largest
    relation id reported while collecting them.
    """

    __slots__ = ["triplets", "max_relation"]

    def __init__(self, triplets: Set[str], max_relation: int):
        self.triplets, self.max_relation = triplets, max_relation
| 39 | + |
| 40 | + |
class IR2VecTripletGenerator:
    """Generate IR2Vec training triplets from LLVM IR files.

    Each input file is run through ``opt`` at a random selection of
    optimization levels and the result is fed to ``llvm-ir2vec`` in triplet
    mode. The union of all triplets is written to ``train2id.txt``, alongside
    ``entity2id.txt`` and ``relation2id.txt``, in the output directory.
    """

    def __init__(
        self,
        llvm_build_dir: Path,
        num_optimizations: int,
        output_dir: Path,
        max_workers: int = DEFAULT_MAX_WORKERS,
    ):
        """Store the configuration and validate it.

        Args:
            llvm_build_dir: LLVM build tree containing ``bin/opt`` and
                ``bin/llvm-ir2vec``.
            num_optimizations: Number of distinct levels from ``OPT_LEVELS``
                applied to each input file.
            output_dir: Directory for the generated files; created if missing.
            max_workers: Size of the thread pool used to process files.

        Raises:
            FileNotFoundError: If the build directory or one of the tools is
                missing (or the tool is not executable).
            ValueError: If ``num_optimizations`` is outside
                ``1..len(OPT_LEVELS)``.
        """
        self.llvm_build_dir = llvm_build_dir
        self.num_optimizations = num_optimizations
        self.output_dir = output_dir
        self.max_workers = max_workers

        # Tool paths
        self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
        self.ir2vec_binary = os.path.join(llvm_build_dir, "bin", "llvm-ir2vec")

        self._validate_setup()

    def _validate_setup(self):
        """Validate that all required tools and paths exist.

        Creates the output directory as a side effect once every check passed.
        """
        if not self.llvm_build_dir.exists():
            raise FileNotFoundError(
                f"LLVM build directory not found: {self.llvm_build_dir}"
            )

        for tool_name, tool_path in (
            ("opt", self.opt_binary),
            ("llvm-ir2vec", self.ir2vec_binary),
        ):
            if not os.path.isfile(tool_path) or not os.access(tool_path, os.X_OK):
                raise FileNotFoundError(
                    f"{tool_name} binary not found or not executable: {tool_path}"
                )

        if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
            raise ValueError(
                f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
            )

        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _select_optimization_levels(self) -> List[str]:
        """Return ``num_optimizations`` distinct, randomly chosen levels."""
        return random.sample(OPT_LEVELS, self.num_optimizations)

    def _process_single_file(self, input_file: Path) -> TripletResult:
        """Process one LLVM IR file at several optimization levels.

        A failure at one level is logged and does not prevent the remaining
        levels from being tried.

        Returns:
            The union of the triplets from all levels together with the
            largest relation id seen (at least 1).
        """
        all_triplets = set()
        max_relation = 1
        opt_levels = self._select_optimization_levels()

        for opt_level in opt_levels:
            try:
                triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
                if triplets:
                    all_triplets.update(triplets)
                    max_relation = max(max_relation, file_max_relation)
                    logger.debug(
                        f"Generated {len(triplets)} triplets for {input_file} with {opt_level}"
                    )
            except Exception as e:
                logger.warning(f"Error processing {input_file} with {opt_level}: {e}")

        return TripletResult(all_triplets, max_relation)

    def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
        """Run ``opt`` and pipe its output into ``llvm-ir2vec``.

        Both tools are started from argument lists without a shell, so file
        names containing quotes or other shell metacharacters are passed
        through verbatim and the exit status of each stage is checked.

        Returns:
            ``(triplets, max_relation)`` as produced by
            ``_parse_triplet_output``, or ``(set(), 1)`` if either tool exits
            with a non-zero status.
        """
        opt_cmd = [str(self.opt_binary), f"-{opt_level}", str(input_file), "-o", "-"]
        ir2vec_cmd = [str(self.ir2vec_binary), "--mode=triplets", "-", "-o", "-"]

        try:
            # The intermediate module is handed over as raw bytes; it is not
            # text and must not be decoded.
            opt_result = subprocess.run(opt_cmd, capture_output=True, check=True)
            ir2vec_result = subprocess.run(
                ir2vec_cmd,
                input=opt_result.stdout,
                capture_output=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            # Best effort: skip this file/level combination, but report it
            # instead of dropping it silently.
            tool = os.path.basename(e.cmd[0])
            detail = (e.stderr or b"").decode(errors="replace").strip()
            logger.warning(
                f"{tool} failed for {input_file} with {opt_level} "
                f"(exit code {e.returncode}): {detail}"
            )
            return set(), 1

        return self._parse_triplet_output(
            ir2vec_result.stdout.decode(errors="replace")
        )

    def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
        """Split tool output into unique triplet lines and the max relation.

        If the first non-blank line has the form ``MAX_RELATION=<int>`` it is
        consumed as metadata; every other non-blank line is kept as a triplet.

        Returns:
            ``(triplets, max_relation)``; ``(set(), 1)`` for blank output.

        Raises:
            ValueError: If the ``MAX_RELATION=`` value is not an integer.
        """
        # splitlines() copes with "\r\n"; blank lines are never triplets.
        lines = [line for line in output.splitlines() if line.strip()]
        if not lines:
            return set(), 1

        max_relation = 1

        # Extract max relation from metadata line
        if lines[0].startswith("MAX_RELATION="):
            max_relation = int(lines[0].split("=")[1])
            lines = lines[1:]

        # Remove duplicate triplets by converting to a set
        return set(lines), max_relation

    def generate_triplets(self, file_list: Path) -> None:
        """Generate triplets for every file named in ``file_list``.

        Files are processed concurrently in a thread pool. Per-file errors
        are logged and skipped; the output files are written once all files
        have been handled.

        Raises:
            ValueError: If ``file_list`` names no existing file.
        """
        input_files = self._read_file_list(file_list)
        logger.info(
            f"Processing {len(input_files)} files with {self.num_optimizations} "
            f"optimization levels using {self.max_workers} workers"
        )

        all_triplets = set()
        global_max_relation = 1

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_file = {
                executor.submit(self._process_single_file, file): file
                for file in input_files
            }

            for future in as_completed(future_to_file):
                try:
                    result = future.result()
                    all_triplets.update(result.triplets)
                    global_max_relation = max(global_max_relation, result.max_relation)
                except Exception as e:
                    file_path = future_to_file[future]
                    logger.error(f"Error processing {file_path}: {e}")

        self._generate_output_files(all_triplets, global_max_relation)
        logger.info("Processing completed successfully")

    def _read_file_list(self, file_list: Path) -> List[Path]:
        """Read ``file_list`` (one path per line) and keep the existing files.

        Blank lines are ignored; a warning is logged for each missing path.

        Raises:
            ValueError: If no listed file exists.
        """
        input_files = []
        with open(file_list, "r") as f:
            for line_num, line in enumerate(f, 1):
                if line := line.strip():
                    file_path = Path(line)
                    if file_path.exists():
                        input_files.append(file_path)
                    else:
                        logger.warning(f"File not found (line {line_num}): {file_path}")

        if not input_files:
            raise ValueError("No valid input files found")
        return input_files

    def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
        """Write train2id.txt, entity2id.txt and relation2id.txt.

        ``train2id.txt`` starts with the triplet count followed by one
        triplet per line.
        """
        logger.info(f"Generating output files with {len(all_triplets)} unique triplets")

        # Write all output files -- train2id.txt, entity2id.txt, relation2id.txt
        train2id_file = self.output_dir / "train2id.txt"
        entity2id_file = self.output_dir / "entity2id.txt"
        relation2id_file = self.output_dir / "relation2id.txt"

        with open(train2id_file, "w") as f:
            f.write(f"{len(all_triplets)}\n")
            f.writelines(f"{triplet}\n" for triplet in all_triplets)

        self._generate_entity2id(entity2id_file)
        self._generate_relation2id(relation2id_file, max_relation)

    def _generate_entity2id(self, output_file: Path) -> None:
        """Generate entity2id.txt by running llvm-ir2vec in entities mode.

        Raises:
            subprocess.CalledProcessError: If llvm-ir2vec exits non-zero.
        """
        subprocess.run(
            [str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
            check=True,
            capture_output=True,
        )

    def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
        """Write relation2id.txt for relation ids ``0..max_relation``.

        The first line is the relation count; ids 0 and 1 are named ``Type``
        and ``Next``, and every id ``i >= 2`` is named ``Arg<i-2>``.
        """
        max_relation = max(max_relation, 1)  # At least Type and Next relations
        num_relations = max_relation + 1

        with open(output_file, "w") as f:
            f.write(f"{num_relations}\n")
            f.write("Type\t0\n")
            f.write("Next\t1\n")
            f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
| 224 | + |
| 225 | + |
def main():
    """Command-line entry point.

    Parses the arguments, sets up logging according to -v/-q, and runs the
    triplet generator. Any error is logged and turned into exit status 1.
    """
    cli = argparse.ArgumentParser(
        description="Generate IR2Vec triplets from LLVM IR files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Positional arguments, in the order they appear on the command line.
    cli.add_argument("llvm_build_dir", type=Path, help="Path to LLVM build directory")
    cli.add_argument(
        "num_optimizations",
        type=int,
        help="Number of optimization levels to apply (1-6)",
    )
    cli.add_argument(
        "ll_file_list",
        type=Path,
        help="File containing list of LLVM IR files to process",
    )
    cli.add_argument(
        "output_dir", type=Path, help="Output directory for generated files"
    )

    # Optional switches.
    cli.add_argument(
        "-j",
        "--max-workers",
        type=int,
        default=DEFAULT_MAX_WORKERS,
        help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
    )
    cli.add_argument(
        "-v", "--verbose", action="store_true", help="Enable debug logging"
    )
    cli.add_argument(
        "-q", "--quiet", action="store_true", help="Suppress all output except errors"
    )

    options = cli.parse_args()

    # --quiet takes precedence over --verbose when both are given.
    if options.quiet:
        log_level = logging.ERROR
    elif options.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO

    logging.basicConfig(
        level=log_level,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    try:
        IR2VecTripletGenerator(
            options.llvm_build_dir,
            options.num_optimizations,
            options.output_dir,
            options.max_workers,
        ).generate_triplets(options.ll_file_list)
    except Exception as e:
        logger.error(f"Error: {e}")
        sys.exit(1)
| 288 | + |
| 289 | + |
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
0 commit comments