diff --git a/tools/hardware_utilization_toolkit/batch_ncu_kernels.py b/tools/hardware_utilization_toolkit/batch_ncu_kernels.py new file mode 100644 index 000000000000..628817bde964 --- /dev/null +++ b/tools/hardware_utilization_toolkit/batch_ncu_kernels.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +import argparse +import csv +import os +import re +import shlex +import subprocess +import sys +from dataclasses import dataclass +from typing import List, Optional + +import pandas as pd + + +# ------------------------------- +# Data structure +# ------------------------------- + +@dataclass +class KernelInfo: + name: str + total_time_ms: float + time_pct: float + + +# ------------------------------- +# Helpers for NSys parsing +# ------------------------------- + +def _parse_duration_to_ms(s: str) -> float: + """ + Parse strings like: + '26.880 μs', '443.106 μs', '3.5 ms', '1.2 s' + into milliseconds. + """ + s = str(s).strip() + if not s: + return 0.0 + + parts = s.split() + if len(parts) != 2: + return 0.0 + + val_str, unit = parts + try: + val = float(val_str.replace(",", "")) + except ValueError: + return 0.0 + + unit = unit.lower() + if unit.startswith("μs") or unit.startswith("us"): + # microseconds -> ms + return val / 1000.0 + if unit.startswith("ns"): + # nanoseconds -> ms + return val / 1e6 + if unit.startswith("ms"): + # already ms + return val + if unit.startswith("s"): + # seconds -> ms + return val * 1000.0 + return 0.0 + + +def read_nsys_kernels_from_excel( + xlsx_path: str, + min_time_ms: float = 0.0, + top_k: Optional[int] = 20, +) -> List[KernelInfo]: + """ + Read an NSys Excel file with per-launch rows and aggregate by kernel name. + + Expected columns: + - 'Kernel Name' + - 'Kernel Duration' (e.g. '26.880 μs', '3.5 ms', etc.) + """ + df = pd.read_excel(xlsx_path) + + if "Kernel Name" not in df.columns or "Kernel Duration" not in df.columns: + print( + f"[ERROR] Excel file {xlsx_path} must contain 'Kernel Name' and 'Kernel Duration' columns.", + file=sys.stderr, + ) + return [] + + # Drop rows without kernel names + df = df.dropna(subset=["Kernel Name"]) + + # Convert duration string -> ms + df["dur_ms"] = df["Kernel Duration"].apply(_parse_duration_to_ms) + + # Aggregate by kernel name + grouped = df.groupby("Kernel Name", as_index=False)["dur_ms"].sum() + grouped = grouped.sort_values("dur_ms", ascending=False) + + total_time_ms = grouped["dur_ms"].sum() + if total_time_ms <= 0: + print(f"[WARN] Total kernel time is zero in {xlsx_path}", file=sys.stderr) + return [] + + kernels: List[KernelInfo] = [] + for _, row in grouped.iterrows(): + name = str(row["Kernel Name"]).strip() + t_ms = float(row["dur_ms"]) + if t_ms < min_time_ms: + continue + pct = (t_ms / total_time_ms) * 100.0 + kernels.append(KernelInfo(name=name, total_time_ms=t_ms, time_pct=pct)) + + if top_k is not None and top_k > 0: + kernels = kernels[:top_k] + + return kernels + + +def read_nsys_kernels_from_csv( + csv_path: str, + min_time_ms: float = 0.0, + top_k: Optional[int] = 20, +) -> List[KernelInfo]: + """ + Read an NSys cudaKernSummary CSV file (nsys stats --report cudaKernSummary). 
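+
+    Column names vary across NSys versions, so a few common variants are
+    tried below. An illustrative header (an assumption, not guaranteed for
+    every NSys release):
+
+        Time(%),Time (ms),Instances,Name
+        41.20,1234.567,100,fused_moe_kernel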
+ """ + kernels: List[KernelInfo] = [] + + with open(csv_path, newline="") as f: + # Filter out comment lines starting with '#' + lines = [ln for ln in f if ln.strip() and not ln.startswith("#")] + if not lines: + return kernels + + headers = [h.strip() for h in lines[0].strip().split(",")] + reader = csv.DictReader(lines[1:], fieldnames=headers) + + for row in reader: + name = row.get("Name") or row.get("Kernel Name") or "" + name = name.strip() + if not name: + continue + + time_str = row.get("Time (ms)") or row.get("Time (ns)") or "" + pct_str = row.get("Time(%)") or row.get("Time (%)") or "" + if not time_str or not pct_str: + continue + + try: + t_val = float(time_str) + if "Time (ns)" in headers: + total_ms = t_val * 1e-6 + else: + total_ms = t_val + pct = float(pct_str) + except ValueError: + continue + + if total_ms < min_time_ms: + continue + + kernels.append(KernelInfo(name=name, total_time_ms=total_ms, time_pct=pct)) + + kernels.sort(key=lambda k: k.time_pct, reverse=True) + if top_k is not None and top_k > 0: + kernels = kernels[:top_k] + return kernels + + +def read_nsys_kernels( + path: str, + min_time_ms: float = 0.0, + top_k: Optional[int] = 20, +) -> List[KernelInfo]: + """ + Wrapper that supports: + - Excel (.xlsx/.xls) from NSys + - CSV cudaKernSummary + """ + ext = os.path.splitext(path)[1].lower() + if ext in [".xlsx", ".xls"]: + return read_nsys_kernels_from_excel(path, min_time_ms=min_time_ms, top_k=top_k) + else: + return read_nsys_kernels_from_csv(path, min_time_ms=min_time_ms, top_k=top_k) + + +# ------------------------------- +# Helpers for kernel names & filenames +# ------------------------------- + +def sanitize_for_filename(name: str) -> str: + """ + Turn a kernel name into a safe (short) filename component. + """ + base = re.sub(r"[^a-zA-Z0-9_]+", "_", name) + return base[:80] + + +def to_regex_pattern(full_name: str) -> str: + """ + Build a robust regex pattern from a full demangled kernel name. + + Heuristics: + - If it contains known stable substrings (fused_moe_kernel, FlashAttnFwdSm90, etc), + use those directly. + - Otherwise: + * strip the trailing '(...)' parameter list + * take the first ~120 chars + * escape for regex + """ + s = str(full_name) + + # 1) Known important substrings (you can extend this list) + known_keys = [ + "fused_moe_kernel", + "FlashAttnFwdSm90", + "FlashAttnFwdCombine", + "flash::FlashAttnFwdSm90", + "nvjet_tst_", + "nvjet_tss_", + "triton_per_fused_", + "triton_poi_fused_", + "triton_red_fused_", + "triton_tem_fused_", + "ncclDevKernel_", + "vllm::moe::", + ] + for key in known_keys: + if key in s: + return re.escape(key) + + # 2) Generic fallback: strip param-list and escape prefix + no_params = s.split("(", 1)[0].strip() # remove "(T1::Params)" etc. + if not no_params: + no_params = s.strip() + prefix = no_params[:120] + return re.escape(prefix) + + +# ------------------------------- +# Nsight Compute: run & parse metrics +# ------------------------------- + +def run_ncu_for_kernel( + kernel_name: str, + run_cmd: str, + out_dir: str, + workload: str, + index: int, + launch_skip: int = 0, + launch_count: int = 1, + dry_run: bool = False, +) -> str: + """ + Run NCU for a single kernel and return the metrics CSV path. 
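+
+    For reference, the constructed command has roughly this shape (the
+    pattern, paths, and target app below are illustrative placeholders):
+
+        ncu -f --section LaunchStats --section Occupancy --section SpeedOfLight \
+            --section SpeedOfLight_HierarchicalTensorRooflineChart \
+            --kernel-name-base demangled --kernel-name regex:<pattern> \
+            --launch-skip 0 --launch-count 1 --metrics <metrics> \
+            --csv --log-file <out>_metrics.csv -o <out> -- <run_cmd>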
+
+    We:
+    - filter by kernel name using regex (robust against template noise)
+    - request the MFU/MBU/time metrics directly via --metrics
+    - write both a .ncu-rep (for later GUI inspection) and a CSV (for the script)
+    """
+    os.makedirs(out_dir, exist_ok=True)
+
+    safe = sanitize_for_filename(kernel_name)
+    rep_base = f"ncu_{workload}_k{index:03d}_{safe}"
+    rep_path = os.path.join(out_dir, rep_base)
+    # NCU writes --log-file exactly as given (no extension is appended), so
+    # build the full .csv path up front and use it consistently below.
+    csv_path = rep_path + "_metrics.csv"
+
+    # Metrics we care about for MFU / MBU / duration
+    metrics = [
+        "sm__throughput.avg.pct_of_peak_sustained_elapsed",
+        "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed",
+        "gpu__time_duration.sum",
+    ]
+
+    pattern = to_regex_pattern(kernel_name)
+    regex_arg = f"regex:{pattern}"
+
+    sections = [
+        "LaunchStats",
+        "Occupancy",
+        "SpeedOfLight",
+        "SpeedOfLight_HierarchicalTensorRooflineChart",
+    ]
+
+    cmd = ["ncu", "-f"]  # -f to overwrite existing .ncu-rep
+
+    for sec in sections:
+        cmd.extend(["--section", sec])
+
+    cmd.extend(
+        [
+            "--kernel-name-base",
+            "demangled",
+            "--kernel-name",
+            regex_arg,
+            "--launch-skip",
+            str(launch_skip),
+            "--launch-count",
+            str(launch_count),
+            "--metrics",
+            ",".join(metrics),
+            "--csv",
+            "--log-file",
+            csv_path,
+            "-o",
+            rep_path,
+            "--",
+        ]
+    )
+
+    cmd.extend(shlex.split(run_cmd))
+
+    print(f"\n[NCU] workload={workload}, kernel #{index}: {kernel_name}")
+    print(f"  Regex pattern: {regex_arg}")
+    print("  .ncu-rep    :", rep_path + ".ncu-rep")
+    print("  Metrics CSV :", csv_path)
+    print("  Command     :", " ".join(shlex.quote(c) for c in cmd))
+
+    if dry_run:
+        return csv_path
+
+    result = subprocess.run(cmd)
+    if result.returncode != 0:
+        print(
+            f"[WARN] NCU failed for kernel: {kernel_name} (exit={result.returncode})",
+            file=sys.stderr,
+        )
+    return csv_path
+
+
+def parse_metrics_csv(csv_path: str):
+    """
+    Parse the metrics CSV produced by NCU with:
+        --metrics sm__throughput.avg.pct_of_peak_sustained_elapsed,
+                  gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed,
+                  gpu__time_duration.sum
+        --csv --log-file
+
+    Returns:
+        {
+            "sm_pct_of_peak": float or None,
+            "dram_pct_of_peak": float or None,
+            "duration_ns": float or None,
+        }
+    """
+    metrics = {
+        "sm_pct_of_peak": None,
+        "dram_pct_of_peak": None,
+        "duration_ns": None,
+    }
+
+    if not os.path.exists(csv_path):
+        print(f"[WARN] Metrics CSV not found: {csv_path}", file=sys.stderr)
+        return metrics
+
+    with open(csv_path, newline="") as f:
+        # Skip blank lines, '#' comments, and the '==PROF== ...' banner
+        # lines that NCU writes into the same log file.
+        lines = [
+            ln for ln in f
+            if ln.strip() and not ln.startswith("#") and not ln.startswith("==")
+        ]
+        if not lines:
+            return metrics
+
+        # Let the csv module parse the (quoted) header row itself rather
+        # than splitting it on ',' by hand.
+        reader = csv.DictReader(lines)
+
+        for row in reader:
+            metric_name = (row.get("Metric Name") or row.get("ID") or "").strip()
+            val_str = (row.get("Metric Value") or row.get("Value") or "").strip()
+            if not metric_name or not val_str:
+                continue
+            try:
+                # NCU may format large values with thousands separators.
+                val = float(val_str.replace(",", ""))
+            except ValueError:
+                continue
+
+            mn = metric_name
+            if "sm__throughput" in mn and "pct_of_peak" in mn:
+                metrics["sm_pct_of_peak"] = val
+            elif "dram" in mn and "pct_of_peak" in mn:
+                # Matches both dram__throughput and gpu__dram_throughput.
+                metrics["dram_pct_of_peak"] = val
+            elif "gpu__time_duration" in mn and "sum" in mn:
+                metrics["duration_ns"] = val
+
+    return metrics
+
+
+# -------------------------------
+# Main
+# -------------------------------
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument(
+        "--workload",
+        required=True,
+        help="Name of workload: full | prefill | decode (used in output filenames).",
+    )
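+    # Example end-to-end invocation (file names below are illustrative):
+    #   python3 batch_ncu_kernels.py --workload full \
+    #     --nsys-kern-csv outputs/full/nsys_kern_summary.csv \
+    #     --run-cmd 'python3 run_bench.py --target_length 31744 --output_length 700' \
+    #     --out-dir outputs/full --top-k 20
+    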
ap.add_argument( + "--nsys-kern-csv", + required=True, + help="Path to NSys kernel summary (CSV or Excel, e.g. Nsys.xlsx).", + ) + ap.add_argument( + "--run-cmd", + required=True, + help=( + "Command to run under NCU, e.g. " + "'python3 run_bench.py --target_length 31744 --output_length 700'" + ), + ) + ap.add_argument( + "--out-dir", + required=True, + help="Output directory (e.g. outputs/full).", + ) + ap.add_argument( + "--min-time-ms", + type=float, + default=0.0, + help="Filter kernels with total time < this (ms).", + ) + ap.add_argument( + "--top-k", + type=int, + default=20, + help="Only profile top-K kernels by total time; 0 = all.", + ) + ap.add_argument( + "--dry-run", + action="store_true", + help="Print NCU commands without executing them.", + ) + args = ap.parse_args() + + top_k = None if args.top_k == 0 else args.top_k + kernels = read_nsys_kernels( + args.nsys_kern_csv, + min_time_ms=args.min_time_ms, + top_k=top_k, + ) + + if not kernels: + print( + "No kernels found from NSys summary with given filters.", + file=sys.stderr, + ) + sys.exit(1) + + print(f"Found {len(kernels)} hot kernels (workload={args.workload}):") + for k in kernels[:10]: + print( + f" {k.name} time={k.total_time_ms:.3f} ms pct={k.time_pct:.2f}%" + ) + + # Run NCU for each kernel + rep_paths = [] + for idx, kinfo in enumerate(kernels, start=1): + csv_path = run_ncu_for_kernel( + kernel_name=kinfo.name, + run_cmd=args.run_cmd, + out_dir=args.out_dir, + workload=args.workload, + index=idx, + launch_skip=0, # tweak if you want to skip warmup launches + launch_count=1, + dry_run=args.dry_run, + ) + rep_paths.append((kinfo, csv_path)) + + if args.dry_run: + print("[DRY-RUN] Skipping aggregation.") + return + + summary_rows = [] + for kinfo, csv_path in rep_paths: + m = parse_metrics_csv(csv_path) + row = { + "workload": args.workload, + "kernel_name": kinfo.name, + "nsys_time_ms": kinfo.total_time_ms, + "nsys_time_pct": kinfo.time_pct, + "ncu_sm_pct_of_peak": m["sm_pct_of_peak"], + "ncu_dram_pct_of_peak": m["dram_pct_of_peak"], + "ncu_duration_ns": m["duration_ns"], + } + summary_rows.append(row) + + out_csv = os.path.join( + args.out_dir, f"ncu_kernel_summary_{args.workload}.csv" + ) + if summary_rows: + fieldnames = list(summary_rows[0].keys()) + os.makedirs(args.out_dir, exist_ok=True) + with open(out_csv, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for r in summary_rows: + writer.writerow(r) + print(f"[OK] Wrote per-kernel summary to {out_csv}") + else: + print("[WARN] No summary rows collected.") + + +if __name__ == "__main__": + main() diff --git a/tools/hardware_utilization_toolkit/calculate_mfu_mbu.py b/tools/hardware_utilization_toolkit/calculate_mfu_mbu.py new file mode 100644 index 000000000000..d3cee10787fe --- /dev/null +++ b/tools/hardware_utilization_toolkit/calculate_mfu_mbu.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +import argparse +import csv +from collections import defaultdict +from typing import Dict, List + + +def load_kernel_summary(path: str) -> List[Dict]: + rows = [] + with open(path, newline="") as f: + reader = csv.DictReader(f) + for row in reader: + rows.append(row) + return rows + + +def compute_mfu_mbu(rows: List[Dict]) -> Dict[str, float]: + """ + rows: from ncu_kernel_summary_*.csv for a single workload. + Uses NSys time fraction + NCU MFU/MBU per kernel. 
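+
+    Time-weighted aggregation, as a sketch of what the loop below computes:
+
+        MFU = sum_k (t_k / T) * (sm_pct_k / 100)
+        MBU = sum_k (t_k / T) * (dram_pct_k / 100)
+
+    where t_k is kernel k's total NSys time and T = sum_k t_k. Kernels with
+    missing NCU metrics keep their time in T but contribute nothing to the
+    numerators, i.e. they are counted as zero utilization.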
+ """ + # Sum time across all kernels (ms) + total_time_ms = 0.0 + for r in rows: + t = float(r["nsys_time_ms"]) + total_time_ms += t + + if total_time_ms == 0.0: + return {"MFU": 0.0, "MBU": 0.0} + + mfu_num = 0.0 + mbu_num = 0.0 + + for r in rows: + t = float(r["nsys_time_ms"]) + sm = r["ncu_sm_pct_of_peak"] + dram = r["ncu_dram_pct_of_peak"] + if sm == "" or dram == "": + continue + sm = float(sm) / 100.0 + dram = float(dram) / 100.0 + + weight = t / total_time_ms + mfu_num += sm * weight + mbu_num += dram * weight + + return {"MFU": mfu_num, "MBU": mbu_num} + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--full", help="ncu_kernel_summary_full.csv") + ap.add_argument("--prefill", help="ncu_kernel_summary_prefill.csv") + ap.add_argument("--decode", help="ncu_kernel_summary_decode.csv") + args = ap.parse_args() + + workloads = {} + if args.full: + workloads["full"] = args.full + if args.prefill: + workloads["prefill"] = args.prefill + if args.decode: + workloads["decode"] = args.decode + + if not workloads: + print("No workloads provided.") + return + + print("Workload,MFU,MBU") + for name, path in workloads.items(): + rows = load_kernel_summary(path) + res = compute_mfu_mbu(rows) + print(f"{name},{res['MFU']:.4f},{res['MBU']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/tools/hardware_utilization_toolkit/run_bench.py b/tools/hardware_utilization_toolkit/run_bench.py new file mode 100644 index 000000000000..433406f8f94b --- /dev/null +++ b/tools/hardware_utilization_toolkit/run_bench.py @@ -0,0 +1,182 @@ +import argparse +import asyncio +import os +from transformers import AutoTokenizer + +from vllm import AsyncLLMEngine, SamplingParams +from vllm.entrypoints.api_server import AsyncEngineArgs +from typing import Optional +from vllm.config import CompilationConfig + +################################ Helper Function ################################ +def make_compilation_config( + cuda_graph_mode: str = "PIECEWISE", + compile_sizes_list: Optional[str] = None +) -> CompilationConfig: + """ + Build a CompilationConfig from provided values. 
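+
+    Example (values are illustrative, not defaults):
+
+        make_compilation_config("FULL", "1,2,4,8")
+        # -> CompilationConfig(cudagraph_mode="FULL", compile_sizes=[1, 2, 4, 8])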
+    """
+    kwargs = {"cudagraph_mode": cuda_graph_mode}
+
+    if compile_sizes_list:
+        kwargs["compile_sizes"] = [int(size) for size in compile_sizes_list.split(",")]
+
+    return CompilationConfig(**kwargs)
+
+################################ Build Customized N Length Prompt ################################
+
+# Parse CLI args
+parser = argparse.ArgumentParser(description="Run vLLM with customized length prompts.")
+parser.add_argument('--target_length', type=int, default=31744, help='Target prompt token length')
+parser.add_argument('--output_length', type=int, default=700, help='Output token length')
+args = parser.parse_args()
+
+# Prompt/output token lengths come from the CLI arguments above
+TARGET_PROMPT_TOKEN_LENGTH = args.target_length
+OUTPUT_TOKEN_LENGTH = args.output_length
+
+TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "/data/neolite_bf16_ckpt/")
+MODEL_PATH = os.getenv("MODEL_PATH", "/data/neolite_bf16_ckpt/")
+GPU_VLLM_BLOCK_SIZE = int(os.getenv("GPU_VLLM_BLOCK_SIZE", 16))
+KV_CACHE_DTYPE = os.getenv("KV_CACHE_DTYPE", "auto")
+ENABLE_PREFIX_CACHING = os.getenv("ENABLE_PREFIX_CACHING", "False").lower() in ["true", "t", "1"]
+TENSOR_PARALLEL_SIZE = int(os.getenv("TENSOR_PARALLEL_SIZE", 8))
+VLLM_SWAP_SPACE = 4
+VLLM_GPU_MEMORY_UTILIZATION = float(os.getenv("GPU_VLLM_GPU_MEMORY_UTILIZATION", "0.9"))
+MAX_NUM_BATCHED_TOKENS = int(os.getenv("MAX_NUM_BATCHED_TOKENS", 131072))
+MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "16"))
+
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True)
+
+# Base: informative technical content
+intro = (
+    "Title: The State of Artificial Intelligence in 2025\n\n"
+    "Abstract: This report provides an overview of major advances in artificial intelligence as of the year 2025. "
+    "It covers developments in natural language models, reinforcement learning, multimodal systems, and ethical challenges. "
+    "We also provide a critical look at the deployment of large language models in commercial applications and their implications on society.\n\n"
+)
+
+# Simulated long-form content (use coherent technical paragraphs)
+base_paragraph = (
+    "Large language models (LLMs) have demonstrated remarkable performance in a variety of NLP tasks including summarization, translation, and question answering. "
+    "Techniques like retrieval-augmented generation (RAG), parameter-efficient fine-tuning, and speculative decoding are widely adopted. "
+    "At the same time, challenges remain in model alignment, hallucination control, and latency under constrained inference budgets. "
+    "In production environments, inference optimizations such as KV cache reuse, continuous batching, and CUDA Graph integration are critical for cost-effective deployment. "
+    "Moreover, foundation models are increasingly evaluated not just on accuracy but also on robustness, fairness, and interpretability. "
+)
+
+# Compose long prompt
+prompt_text = intro
+while len(tokenizer(prompt_text).input_ids) < TARGET_PROMPT_TOKEN_LENGTH - 100:
+    prompt_text += base_paragraph
+
+# Add a final instruction for generation
+prompt_text += (
+    "\n\n---\n\n"
+    "Based on the content above, please summarize the key challenges facing large-scale model inference in production, "
+    "and propose potential solutions to reduce latency and cost while preserving model quality."
+) + +# Trim to exactly TARGET_PROMPT_TOKEN_LENGTH tokens +input_ids = tokenizer(prompt_text).input_ids[:TARGET_PROMPT_TOKEN_LENGTH] +final_prompt = tokenizer.decode(input_ids, skip_special_tokens=True) + +# Inject into vLLM script +prompts = [final_prompt, final_prompt] +request_ids = ["0", "1"] + +# Optional preview +print(f"✅ Prompt token count: {len(tokenizer(final_prompt).input_ids)}") +print(f"📝 Prompt preview:\n{final_prompt[:1000]}") + +CUDA_GRAPH_MODE = str(os.environ.get("CUDA_GRAPH_MODE", "PIECEWISE")) +COMPILE_SIZES_LIST = os.environ.get("COMPILE_SIZES_LIST", None) +USE_DUMMY_WEIGHT = str(os.environ.get("USE_DUMMY_WEIGHT", "False")).lower() in ["true", "t", "1"] +USE_NEOLITE_MODEL = str(os.environ.get("USE_NEOLITE_MODEL", "False")).lower() in ["true", "t", "1"] + +kwargs = dict( + model=MODEL_PATH, + tokenizer=TOKENIZER_PATH, + tokenizer_mode="auto", + seed=0, + max_num_seqs=MAX_BATCH_SIZE, + max_model_len=131072, + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, + swap_space=VLLM_SWAP_SPACE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + trust_remote_code=False, + disable_log_stats=False, + dtype="auto", + enable_prefix_caching=ENABLE_PREFIX_CACHING, + kv_cache_dtype=KV_CACHE_DTYPE, + block_size=GPU_VLLM_BLOCK_SIZE, + enable_chunked_prefill=True, + collect_time_per_step=False, + compilation_config=make_compilation_config(cuda_graph_mode=CUDA_GRAPH_MODE, compile_sizes_list=COMPILE_SIZES_LIST), +) +if USE_DUMMY_WEIGHT: + kwargs["load_format"] = "dummy" + +if USE_NEOLITE_MODEL: + kwargs["hf_overrides"] = {"architectures": ["NeoLiteMixtralForCausalLM"], "kv_lora_rank": None} + +engine_args = AsyncEngineArgs( + **kwargs, +) +print(f"engine_args is {engine_args}") + + +async def generate_all_requests(prompts, request_ids): + print("🚀 Sending prompts to vLLM engine...") + print("Request_Id: ", request_ids) + async_llm_engine = AsyncLLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams( + top_k=1, + max_tokens=OUTPUT_TOKEN_LENGTH, + stop_token_ids=[int(token) for token in os.getenv("STOP_TOKEN_IDS", "199999,200002").split(",")] + ) + + if USE_NEOLITE_MODEL: + print("Updating sampling params for NeoLite...") + sampling_params = SamplingParams(temperature=0.6, + top_p=1.0, # same as default value in SamplingParams + top_k=-1, + min_p=0.0, # same as default value in SamplingParams + stop="<|im_end|>", + stop_token_ids=[200002], + n=1, # same as default value in SamplingParams + ignore_eos=False, # same as default value in SamplingParams + guided_decoding=None, # same as default value in SamplingParams + max_tokens=OUTPUT_TOKEN_LENGTH, # temp value, update after final optimization + ) + + print(f"sampling params is {sampling_params}") + + async def process_single_prompt(prompt, request_id): + try: + generated_text = "" + async for output in async_llm_engine.generate(prompt, sampling_params, request_id): + generated_text = output.outputs[0].text + print(f"\n{'='*80}") + print(f"🆔 Request ID: {request_id}") + print(f"\n📥 Prompt preview (first 100 chars):\n{prompt[:100]}...") + print(f"\n📤 Generated Text:\n{generated_text}") + print(f"{'='*80}\n") + except Exception as e: + print(f"❌ Error for prompt {request_id}: {str(e)}") + + # Create tasks for all prompts + tasks = [process_single_prompt(prompt, req_id) for prompt, req_id in zip(prompts, request_ids)] + + # Run all tasks concurrently + await asyncio.gather(*tasks) + + +# Run the async function +asyncio.run(generate_all_requests(prompts, request_ids)) + \ No newline 
at end of file
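diff --git a/tools/hardware_utilization_toolkit/README.md b/tools/hardware_utilization_toolkit/README.md
new file mode 100644
--- /dev/null
+++ b/tools/hardware_utilization_toolkit/README.md
@@ -0,0 +1,13 @@
+# Hardware Utilization Toolkit
+
+Intended three-step workflow. The command lines below are illustrative
+sketches, not verified invocations:
+
+1. Run the benchmark under NSys and export a per-kernel summary
+   (e.g. CSV via `nsys stats --report cudaKernSummary --format csv`, or Excel).
+2. Re-profile the hot kernels one at a time with Nsight Compute:
+   `python3 batch_ncu_kernels.py --workload full --nsys-kern-csv <summary> \
+    --run-cmd 'python3 run_bench.py --target_length 31744 --output_length 700' \
+    --out-dir outputs/full --top-k 20`
+3. Aggregate the per-kernel results into time-weighted MFU/MBU:
+   `python3 calculate_mfu_mbu.py --full outputs/full/ncu_kernel_summary_full.csv`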