Skip to content

Commit d19f53d

Browse files
committed
triplet-ext-script
1 parent b209998 commit d19f53d

File tree

2 files changed

+294
-0
lines changed

2 files changed

+294
-0
lines changed

llvm/docs/CommandGuide/llvm-ir2vec.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ embedding training (see
5050
<https://github.com/thunlp/OpenKE/tree/OpenKE-PyTorch?tab=readme-ov-file#data-format>
5151
for details).
5252

53+
See `llvm/utils/mlgo-utils/IR2Vec/generateTriplets.py` for more details on how
54+
these two modes are used to generate the triplets and entity mappings.
55+
5356
Triplet Generation Mode
5457
~~~~~~~~~~~~~~~~~~~~~~~
5558

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
"""IR2Vec Triplet Generator
5+
6+
Generates IR2Vec triplets by applying random optimization levels to LLVM IR files
7+
and extracting triplets using llvm-ir2vec. Automatically generates preprocessed
8+
files: entity2id.txt, relation2id.txt, and train2id.txt.
9+
10+
Usage:
11+
python generateTriplets.py <llvm_build_dir> <num_optimizations> <ll_file_list> <output_dir>
12+
"""
13+
14+
import argparse
15+
import logging
16+
import os
17+
import random
18+
import subprocess
19+
import sys
20+
from concurrent.futures import ThreadPoolExecutor, as_completed
21+
from pathlib import Path
22+
from typing import List, Set, Tuple
23+
24+
# Configuration
25+
OPT_LEVELS = ["O0", "O1", "O2", "O3", "Os", "Oz"]
26+
DEFAULT_MAX_WORKERS = 100
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class TripletResult:
32+
"""Result from processing a single LLVM IR file"""
33+
34+
__slots__ = ["triplets", "max_relation"]
35+
36+
def __init__(self, triplets: Set[str], max_relation: int):
37+
self.triplets = triplets
38+
self.max_relation = max_relation
39+
40+
41+
class IR2VecTripletGenerator:
42+
"""Main class for generating IR2Vec triplets"""
43+
44+
def __init__(
45+
self,
46+
llvm_build_dir: Path,
47+
num_optimizations: int,
48+
output_dir: Path,
49+
max_workers: int = DEFAULT_MAX_WORKERS,
50+
):
51+
self.llvm_build_dir = llvm_build_dir
52+
self.num_optimizations = num_optimizations
53+
self.output_dir = output_dir
54+
self.max_workers = max_workers
55+
56+
# Tool paths
57+
self.opt_binary = os.path.join(llvm_build_dir, "bin", "opt")
58+
self.ir2vec_binary = os.path.join(llvm_build_dir, "bin", "llvm-ir2vec")
59+
60+
self._validate_setup()
61+
62+
def _validate_setup(self):
63+
"""Validate that all required tools and paths exist"""
64+
if not self.llvm_build_dir.exists():
65+
raise FileNotFoundError(
66+
f"LLVM build directory not found: {self.llvm_build_dir}"
67+
)
68+
69+
if not os.path.isfile(self.opt_binary) or not os.access(
70+
self.opt_binary, os.X_OK
71+
):
72+
raise FileNotFoundError(
73+
f"opt binary not found or not executable: {self.opt_binary}"
74+
)
75+
76+
if not os.path.isfile(self.ir2vec_binary) or not os.access(
77+
self.ir2vec_binary, os.X_OK
78+
):
79+
raise FileNotFoundError(
80+
f"llvm-ir2vec binary not found or not executable: {self.ir2vec_binary}"
81+
)
82+
83+
if not (1 <= self.num_optimizations <= len(OPT_LEVELS)):
84+
raise ValueError(
85+
f"Number of optimizations must be between 1-{len(OPT_LEVELS)}"
86+
)
87+
88+
self.output_dir.mkdir(parents=True, exist_ok=True)
89+
90+
def _select_optimization_levels(self) -> List[str]:
91+
"""Select unique random optimization levels"""
92+
return random.sample(OPT_LEVELS, self.num_optimizations)
93+
94+
def _process_single_file(self, input_file: Path) -> TripletResult:
95+
"""Process a single LLVM IR file with multiple optimization levels"""
96+
all_triplets = set()
97+
max_relation = 1
98+
opt_levels = self._select_optimization_levels()
99+
100+
for opt_level in opt_levels:
101+
try:
102+
triplets, file_max_relation = self._run_pipeline(input_file, opt_level)
103+
if triplets:
104+
all_triplets.update(triplets)
105+
max_relation = max(max_relation, file_max_relation)
106+
logger.debug(
107+
f"Generated {len(triplets)} triplets for {input_file} with {opt_level}"
108+
)
109+
except Exception as e:
110+
logger.warning(f"Error processing {input_file} with {opt_level}: {e}")
111+
112+
return TripletResult(all_triplets, max_relation)
113+
114+
def _run_pipeline(self, input_file: Path, opt_level: str) -> Tuple[Set[str], int]:
115+
"""Run opt | llvm-ir2vec pipeline elegantly."""
116+
pipeline_cmd = (
117+
f'"{self.opt_binary}" -{opt_level} "{input_file}" -o - | '
118+
f'"{self.ir2vec_binary}" --mode=triplets - -o -'
119+
)
120+
121+
try:
122+
result = subprocess.run(
123+
pipeline_cmd, shell=True, capture_output=True, text=True, check=True
124+
)
125+
return self._parse_triplet_output(result.stdout)
126+
except subprocess.CalledProcessError:
127+
return set(), 1
128+
129+
def _parse_triplet_output(self, output: str) -> Tuple[Set[str], int]:
130+
"""Parse triplet output and extract max relation"""
131+
if not output.strip():
132+
return set(), 1
133+
134+
lines = output.strip().split("\n")
135+
max_relation = 1
136+
137+
# Extract max relation from metadata line
138+
if lines and lines[0].startswith("MAX_RELATION="):
139+
max_relation = int(lines[0].split("=")[1])
140+
lines = lines[1:]
141+
142+
# Remove duplicate triplets by converting to a set
143+
return set(lines), max_relation
144+
145+
def generate_triplets(self, file_list: Path) -> None:
146+
"""Main method to generate triplets from a list of LLVM IR files"""
147+
input_files = self._read_file_list(file_list)
148+
logger.info(
149+
f"Processing {len(input_files)} files with {self.num_optimizations} "
150+
f"optimization levels using {self.max_workers} workers"
151+
)
152+
153+
all_triplets = set()
154+
global_max_relation = 1
155+
156+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
157+
future_to_file = {
158+
executor.submit(self._process_single_file, file): file
159+
for file in input_files
160+
}
161+
162+
for future in as_completed(future_to_file):
163+
try:
164+
result = future.result()
165+
all_triplets.update(result.triplets)
166+
global_max_relation = max(global_max_relation, result.max_relation)
167+
except Exception as e:
168+
file_path = future_to_file[future]
169+
logger.error(f"Error processing {file_path}: {e}")
170+
171+
self._generate_output_files(all_triplets, global_max_relation)
172+
logger.info("Processing completed successfully")
173+
174+
def _read_file_list(self, file_list: Path) -> List[Path]:
175+
"""Read and validate the list of input files"""
176+
input_files = []
177+
with open(file_list, "r") as f:
178+
for line_num, line in enumerate(f, 1):
179+
if line := line.strip():
180+
file_path = Path(line)
181+
if file_path.exists():
182+
input_files.append(file_path)
183+
else:
184+
logger.warning(f"File not found (line {line_num}): {file_path}")
185+
186+
if not input_files:
187+
raise ValueError("No valid input files found")
188+
return input_files
189+
190+
def _generate_output_files(self, all_triplets: Set[str], max_relation: int) -> None:
191+
"""Generate the final output files"""
192+
logger.info(f"Generating output files with {len(all_triplets)} unique triplets")
193+
194+
# Write all output files -- train2id.txt, entity2id.txt, relation2id.txt
195+
train2id_file = os.path.join(self.output_dir, "train2id.txt")
196+
entity2id_file = os.path.join(self.output_dir, "entity2id.txt")
197+
relation2id_file = os.path.join(self.output_dir, "relation2id.txt")
198+
199+
with open(train2id_file, "w") as f:
200+
f.write(f"{len(all_triplets)}\n")
201+
f.writelines(f"{triplet}\n" for triplet in all_triplets)
202+
203+
self._generate_entity2id(entity2id_file)
204+
self._generate_relation2id(relation2id_file, max_relation)
205+
206+
def _generate_entity2id(self, output_file: Path) -> None:
207+
"""Generate entity2id.txt using llvm-ir2vec"""
208+
subprocess.run(
209+
[str(self.ir2vec_binary), "--mode=entities", "-o", str(output_file)],
210+
check=True,
211+
capture_output=True,
212+
)
213+
214+
def _generate_relation2id(self, output_file: Path, max_relation: int) -> None:
215+
"""Generate relation2id.txt from max relation"""
216+
max_relation = max(max_relation, 1) # At least Type and Next relations
217+
num_relations = max_relation + 1
218+
219+
with open(output_file, "w") as f:
220+
f.write(f"{num_relations}\n")
221+
f.write("Type\t0\n")
222+
f.write("Next\t1\n")
223+
f.writelines(f"Arg{i-2}\t{i}\n" for i in range(2, num_relations))
224+
225+
226+
def main():
227+
"""Main entry point"""
228+
parser = argparse.ArgumentParser(
229+
description="Generate IR2Vec triplets from LLVM IR files",
230+
formatter_class=argparse.RawDescriptionHelpFormatter,
231+
)
232+
233+
parser.add_argument(
234+
"llvm_build_dir", type=Path, help="Path to LLVM build directory"
235+
)
236+
parser.add_argument(
237+
"num_optimizations",
238+
type=int,
239+
help="Number of optimization levels to apply (1-6)",
240+
)
241+
parser.add_argument(
242+
"ll_file_list",
243+
type=Path,
244+
help="File containing list of LLVM IR files to process",
245+
)
246+
parser.add_argument(
247+
"output_dir", type=Path, help="Output directory for generated files"
248+
)
249+
parser.add_argument(
250+
"-j",
251+
"--max-workers",
252+
type=int,
253+
default=DEFAULT_MAX_WORKERS,
254+
help=f"Maximum number of parallel workers (default: {DEFAULT_MAX_WORKERS})",
255+
)
256+
parser.add_argument(
257+
"-v", "--verbose", action="store_true", help="Enable debug logging"
258+
)
259+
parser.add_argument(
260+
"-q", "--quiet", action="store_true", help="Suppress all output except errors"
261+
)
262+
263+
args = parser.parse_args()
264+
265+
# Configure logging
266+
level = (
267+
logging.ERROR
268+
if args.quiet
269+
else (logging.DEBUG if args.verbose else logging.INFO)
270+
)
271+
logging.basicConfig(
272+
level=level,
273+
format="[%(asctime)s] %(levelname)s: %(message)s",
274+
datefmt="%H:%M:%S",
275+
)
276+
277+
try:
278+
generator = IR2VecTripletGenerator(
279+
args.llvm_build_dir,
280+
args.num_optimizations,
281+
args.output_dir,
282+
args.max_workers,
283+
)
284+
generator.generate_triplets(args.ll_file_list)
285+
except Exception as e:
286+
logger.error(f"Error: {e}")
287+
sys.exit(1)
288+
289+
290+
if __name__ == "__main__":
291+
main()

0 commit comments

Comments
 (0)