09. Reproducer Guide
This guide provides comprehensive documentation for TritonParse's reproducer system, which generates standalone Python scripts to reproduce specific kernel executions from trace files.
A reproducer is a self-contained Python script that:
- Recreates the exact execution environment of a kernel launch
- Reconstructs input tensors using various strategies
- Can be run independently without the original codebase
- Enables debugging, benchmarking, and sharing of kernel issues
| Use Case | Description |
|---|---|
| Bug Isolation | Extract a single kernel execution to debug in isolation |
| Performance Analysis | Benchmark specific kernel configurations |
| Issue Sharing | Share reproducible test cases with collaborators |
| Regression Testing | Compare kernel behavior across versions |
| Documentation | Create executable examples of kernel usage |
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Generate Trace │────▶│ Parse Trace │────▶│ Reproduce │
│ (with launch) │ │ (unified_parse)│ │ (reproduce) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
▼ ▼ ▼
*.ndjson logs *.ndjson.gz files repro_*.py scripts
# Generate reproducer for a specific launch event
tritonparseoss reproduce ./parsed_output/trace.ndjson.gz --line 1 --out-dir repro_output
# Using kernel name instead of line index
tritonparseoss reproduce ./trace.ndjson.gz --kernel matmul_kernel --out-dir repro_output
# With custom template
tritonparseoss reproduce ./trace.ndjson.gz --line 1 --template tritonbench --out-dir bench_output

from tritonparse.reproducer.orchestrator import reproduce
result = reproduce(
    input_path="./parsed_output/trace.ndjson.gz",
    line_index=1,  # 0-based index (0 = compilation, 1+ = launches)
    out_dir="./repro_output",
    template="example",  # Built-in template
)
print(f"Script: {result['repro_script']}")
print(f"Context: {result['repro_context']}")

repro_output/<kernel_name>/
├── repro_<timestamp>.py # Standalone executable script
├── repro_context_<timestamp>.json # Kernel metadata, args, and launch info
└── <hash>.bin # Tensor blob files (if enabled during tracing)
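Because file names carry a timestamp, a kernel directory can accumulate several reproducers. The following is a minimal sketch of picking the most recent script and context file from such a directory. The helper name `find_latest_repro` and the demo file names are hypothetical; only the `repro_*.py` and `repro_context_*.json` patterns come from the layout above.

```python
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def find_latest_repro(kernel_dir: str) -> Tuple[Optional[Path], Optional[Path]]:
    """Return the newest (script, context JSON) pair in a kernel output directory.

    Either element is None when no matching file exists.
    """
    root = Path(kernel_dir)

    def newest(pattern: str) -> Optional[Path]:
        # Sort by modification time so the last entry is the most recent file.
        matches = sorted(root.glob(pattern), key=lambda p: p.stat().st_mtime)
        return matches[-1] if matches else None

    return newest("repro_*.py"), newest("repro_context_*.json")


# Demo on a throwaway directory that mimics the layout above.
# The timestamp format in these names is made up for illustration.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "repro_20240101_000000.py").write_text("# placeholder\n")
    (Path(tmp) / "repro_context_20240101_000000.json").write_text("{}\n")
    script, context = find_latest_repro(tmp)
    found_names = (script.name, context.name)

    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()
    nothing_found = find_latest_repro(str(empty_dir))

print(found_names)
```

Sorting by modification time avoids depending on the exact timestamp format embedded in the file names.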
A trace file contains multiple events in chronological order:
| Line Index | Event Type | Description |
|---|---|---|
| 0 | `compilation` | Kernel compilation metadata and IR |
| 1 | `launch` | First kernel execution |
| 2 | `launch` | Second kernel execution |
| ... | ... | ... |
| N | `launch_diff` | Summary of launch variations |
When generating a reproducer:
- `line_index=0` targets the compilation event (rarely used for reproduction)
- `line_index=1` targets the first launch event (most common)
- Higher indices target subsequent launches
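To see which indices are valid launch targets before running the reproducer, the trace can be scanned line by line. This is a rough sketch that assumes each line is a JSON object with an `event_type` field holding values such as `compilation` or `launch`. That field name is an assumption here, and `iter_event_types` is a hypothetical helper, not part of tritonparse.

```python
import gzip
import json
import os
import tempfile
from typing import Iterator, Tuple


def iter_event_types(path: str) -> Iterator[Tuple[int, str]]:
    """Yield (line_index, event_type) for each non-empty line of an NDJSON trace."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt", encoding="utf-8") as f:
        for index, line in enumerate(f):
            if line.strip():
                # "event_type" is an assumed field name; adjust to the real schema.
                yield index, json.loads(line).get("event_type", "<unknown>")


# Demo on a tiny synthetic trace: one compilation followed by two launches.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.ndjson")
    with open(demo_path, "w", encoding="utf-8") as f:
        for kind in ("compilation", "launch", "launch"):
            f.write(json.dumps({"event_type": kind}) + "\n")
    events = list(iter_event_types(demo_path))

launch_indices = [index for index, kind in events if kind == "launch"]
print(launch_indices)  # [1, 2]
```

The built-in `tritonparseoss info` command serves the same purpose for real traces; the sketch only illustrates how indices map to events.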
Internally, the reproducer builds a ContextBundle containing all information needed:
@dataclass
class ContextBundle:
    kernel_info: KernelInfo  # Function name, file path, source code
    compile: Dict[str, Any]  # num_warps, num_stages, arch, backend
    launch: Dict[str, Any]  # grid, kwargs
    args: Dict[str, Any]  # All arguments (scalars + tensors)
    tensor_args: Dict[str, Any]  # Tensor-specific information
    raw_launch_event: Dict  # Original launch event data
    raw_comp_event: Dict  # Original compilation event data

The kernel import mode controls how the kernel function is imported in the generated script:
| Mode | Description | When to Use |
|---|---|---|
| `DEFAULT` | Import from original source file | Standard case, original file accessible |
| `COPY` | Embed kernel source in reproducer | Share without original codebase |
| `OVERRIDE_TTIR` | Use TTIR with monkeypatch | Debug specific IR versions |
CLI Usage:
tritonparseoss reproduce trace.ndjson.gz --line 1 --kernel-import copy
tritonparseoss reproduce trace.ndjson.gz --line 1 --kernel-import override-ttir

Python Usage:
from tritonparse.reproducer.orchestrator import reproduce
from tritonparse.reproducer.types import KernelImportMode
result = reproduce(
    input_path="./trace.ndjson.gz",
    line_index=1,
    out_dir="./repro",
    kernel_import=KernelImportMode.COPY,
)

The reproducer supports three tensor reconstruction strategies, applied in priority order:
Exact tensor data saved during tracing.
Enable during tracing:
tritonparse.structured_logging.init(
    "./logs/",
    enable_trace_launch=True,
    enable_tensor_blob_storage=True,  # Save actual tensor data
    tensor_storage_quota=10 * 1024**3,  # 10GB limit
)

Or with environment variables:
export TRITONPARSE_SAVE_TENSOR_BLOBS=1
export TRITONPARSE_TENSOR_STORAGE_QUOTA=10737418240  # 10GB

Behavior:
- Saves `.bin` files alongside traces
- Reproducer loads exact values via `load_tensor()`
- Best for numerical accuracy debugging
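The quota is expressed in bytes, and the "10GB" in the examples is 10 × 1024³ bytes. A small sketch of deriving that value and preparing the two environment variables for a traced subprocess follows. The helper `gib_to_bytes` is hypothetical; the variable names are the ones shown above, and when tritonparse reads them is not covered here.

```python
import os


def gib_to_bytes(gib: float) -> int:
    """Convert a size in GiB (1024**3 bytes) to a whole number of bytes."""
    return int(gib * 1024**3)


quota_bytes = gib_to_bytes(10)

# Copy the current environment and add the tracing settings, e.g. to pass
# as `env=` to subprocess.run() when launching the traced program.
trace_env = dict(os.environ)
trace_env["TRITONPARSE_SAVE_TENSOR_BLOBS"] = "1"
trace_env["TRITONPARSE_TENSOR_STORAGE_QUOTA"] = str(quota_bytes)

print(quota_bytes)  # 10737418240
```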
Uses saved statistics to generate similar data.
Enable during tracing:
tritonparse.structured_logging.init(
    "./logs/",
    enable_trace_launch=True,
    enable_more_tensor_information=True,  # Save statistics
)

Statistics saved:
- `mean` - Average value
- `std` - Standard deviation
- `min` - Minimum value
- `max` - Maximum value
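For intuition about how much these four numbers capture, here is a torch-free sketch that summarizes a list of values and then draws new values with a similar spread, scaling a normal sample by `std`, shifting by `mean`, and clamping to `[min, max]`. The helpers `summarize` and `regenerate` are hypothetical, and the use of the sample standard deviation is an assumption rather than a statement about what tritonparse records.

```python
import random
import statistics
from typing import Dict, List


def summarize(values: List[float]) -> Dict[str, float]:
    """Compute the four summary statistics for a flat list of values."""
    return {
        "mean": statistics.fmean(values),
        "std": statistics.stdev(values),  # sample std (assumed estimator)
        "min": min(values),
        "max": max(values),
    }


def regenerate(stats: Dict[str, float], count: int, seed: int = 0) -> List[float]:
    """Draw `count` normal samples with the given mean/std, clamped to [min, max]."""
    rng = random.Random(seed)
    samples = (rng.gauss(stats["mean"], stats["std"]) for _ in range(count))
    return [min(max(x, stats["min"]), stats["max"]) for x in samples]


stats = summarize([1.0, 2.0, 3.0, 4.0])
approx = regenerate(stats, count=1000)
print(stats["mean"], stats["min"], stats["max"])  # 2.5 1.0 4.0
```

The regenerated values share the range and rough distribution of the original, but not the individual entries, which is why blob storage remains the better choice for numerical accuracy debugging.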
Reconstruction logic:
# For floating point tensors
tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
tensor = torch.clamp(tensor, min=min_val, max=max_val)
tensor = tensor.to(target_dtype)
# For integer tensors
tensor = torch.round(tensor).to(target_dtype)

Basic random generation when no statistics are available.
# Floating point: random values
torch.empty(shape, dtype=dtype, device=device).random_()
# Integer: random integers
torch.empty(shape, dtype=dtype, device=device).random_()
# Complex: random real + imaginary
torch.complex(real_part, imag_part)

The reproducer preserves non-contiguous tensor layouts:
# If tensor has custom stride or storage_offset
strided_view = storage_tensor.as_strided(
    size=shape,
    stride=stride,
    storage_offset=storage_offset
)

This ensures:
- Transposed tensors work correctly
- Sliced tensors maintain their layout
- Memory access patterns match original
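To make the roles of `size`, `stride`, and `storage_offset` concrete, the following pure-Python sketch maps each view index to a position in flat storage, which is the addressing rule a strided view follows. It is an illustration only: `strided_elements` is a hypothetical helper and does not reflect how tritonparse or PyTorch implement views.

```python
from itertools import product
from typing import List, Sequence


def strided_elements(
    storage: Sequence[float],
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: int = 0,
) -> List[float]:
    """Return the view's elements in row-major order of the view's own indices."""
    elements = []
    for index in product(*(range(n) for n in size)):
        # Flat position = offset + sum over dimensions of index * stride.
        flat = storage_offset + sum(i * s for i, s in zip(index, stride))
        elements.append(storage[flat])
    return elements


storage = list(range(6))  # backing buffer of a 2x3 row-major matrix

contiguous = strided_elements(storage, size=(2, 3), stride=(3, 1))
transposed = strided_elements(storage, size=(3, 2), stride=(1, 3))
sliced = strided_elements(storage, size=(2, 2), stride=(3, 1), storage_offset=1)

print(contiguous)  # [0, 1, 2, 3, 4, 5]
print(transposed)  # [0, 3, 1, 4, 2, 5]
print(sliced)      # [1, 2, 4, 5]
```

The transposed and sliced views read the same buffer in a different order, so a kernel that receives them touches memory differently than it would with a freshly allocated contiguous tensor of the same shape.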
| Template | Description | Use Case |
|---|---|---|
| `example` | Basic standalone script | General debugging |
| `tritonbench` | TritonBench-compatible operator | Performance benchmarking |
Templates use placeholders that get replaced during generation:
| Placeholder | Description |
|---|---|
| `{{KERNEL_IMPORT_PLACEHOLDER}}` | Import statements for the kernel |
| `{{KERNEL_INVOCATION_PLACEHOLDER}}` | Kernel launch code with arguments |
| `{{KERNEL_SYSPATH_PLACEHOLDER}}` | sys.path setup for imports |
| `{{JSON_FILE_NAME_PLACEHOLDER}}` | Context JSON filename |
| `{{UTILITY_FUNCTIONS_PLACEHOLDER}}` | Helper functions (create_args_from_json, etc.) |
| `{{IR_OVERRIDE_SETUP_PLACEHOLDER}}` | TTIR override setup (for OVERRIDE_TTIR mode) |
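Conceptually, generation is a text substitution over the template. A minimal sketch of that idea is shown here, together with a check for placeholders that were left unfilled, which can help when authoring a custom template. The functions `fill_placeholders` and `unresolved` are hypothetical and only substitute the bare `{{NAME}}` token; the actual replacer in tritonparse may handle more, such as surrounding comment markers.

```python
import re
from typing import Dict, List


def fill_placeholders(template: str, values: Dict[str, str]) -> str:
    """Replace each {{NAME}} token with its text; unknown tokens are left as-is."""
    for name, text in values.items():
        template = template.replace("{{" + name + "}}", text)
    return template


def unresolved(code: str) -> List[str]:
    """List placeholder names that still appear in the generated code."""
    return re.findall(r"\{\{(\w+)\}\}", code)


template = (
    'json_file = "{{JSON_FILE_NAME_PLACEHOLDER}}"\n'
    "{{MY_CUSTOM_PLACEHOLDER}}\n"
)
filled = fill_placeholders(
    template, {"JSON_FILE_NAME_PLACEHOLDER": "repro_context_demo.json"}
)
leftover = unresolved(filled)

print(filled.splitlines()[0])  # json_file = "repro_context_demo.json"
print(leftover)                # ['MY_CUSTOM_PLACEHOLDER']
```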
Step 1: Create template file
# my_template.py
"""Custom reproducer template for my use case"""
import torch
import logging

logger = logging.getLogger(__name__)

# {{IR_OVERRIDE_SETUP_PLACEHOLDER}}
# {{KERNEL_SYSPATH_PLACEHOLDER}}
# {{KERNEL_IMPORT_PLACEHOLDER}}
# {{UTILITY_FUNCTIONS_PLACEHOLDER}}

def run_kernel():
    """Execute the reproduced kernel."""
    from pathlib import Path
    script_dir = Path(__file__).resolve().parent
    json_file = script_dir / "{{JSON_FILE_NAME_PLACEHOLDER}}"
    grid, args_dict = create_args_from_json_file(str(json_file))
    print("=" * 60)
    print("CUSTOM REPRODUCER")
    print("=" * 60)
    print(f"Grid: {grid}")
    for name, arg in args_dict.items():
        if torch.is_tensor(arg):
            print(f"  {name}: tensor {arg.shape} {arg.dtype}")
        else:
            print(f"  {name}: {arg}")
    # {{KERNEL_INVOCATION_PLACEHOLDER}}
    torch.cuda.synchronize()
    print("Execution complete!")

if __name__ == "__main__":
    run_kernel()

Step 2: Use the template
tritonparseoss reproduce trace.ndjson.gz --line 1 --template /path/to/my_template.py

result = reproduce(
    input_path="./trace.ndjson.gz",
    line_index=1,
    template="/path/to/my_template.py",
    out_dir="./repro",
)

The built-in tritonbench template generates a TritonBench-compatible operator:
tritonparseoss reproduce trace.ndjson.gz --line 1 --template tritonbench --out-dir bench

This creates a script compatible with:
python -m tritonbench --op <operator_name> --mode latency

Instead of specifying line_index, you can find kernels by name:
from tritonparse.reproducer.orchestrator import reproduce
result = reproduce(
    input_path="./trace.ndjson.gz",
    kernel_name="matmul_kernel",  # Find by exact name
    launch_id=0,  # Which launch instance (0 = first)
    out_dir="./repro",
)

CLI equivalent:
tritonparseoss reproduce trace.ndjson.gz --kernel matmul_kernel --launch-id 0

For advanced customization, implement a custom replacer:
from tritonparse.reproducer.placeholder_replacer import PlaceholderReplacer, DefaultPlaceholderReplacer
from tritonparse.reproducer.orchestrator import reproduce
class MyReplacer(DefaultPlaceholderReplacer):
    def replace(self, template_code, context_bundle, **kwargs):
        # Call parent for standard replacements
        code = super().replace(template_code, context_bundle, **kwargs)
        # Add custom modifications
        code = code.replace("{{MY_CUSTOM_PLACEHOLDER}}", "my_value")
        return code

result = reproduce(
    input_path="./trace.ndjson.gz",
    line_index=1,
    out_dir="./repro",
    replacer=MyReplacer(),
)

For projects using the triton_kernels library:
# The reproducer automatically handles these types if triton_kernels is installed:
from triton_kernels.tensor import Tensor, Storage, StridedLayout

Supported types:
- `triton_kernels.tensor.Tensor`
- `triton_kernels.tensor.Storage`
- `StridedLayout`
If not installed, you'll see:
RuntimeError: Optional dependency 'triton_kernels.tensor' is not installed
Solution:
pip install triton_kernels

tritonparseoss reproduce <input_file> [options]

Arguments:
| Argument | Description |
|---|---|
| `<input_file>` | Path to trace file (.ndjson or .ndjson.gz) |
Options:
| Option | Description | Default |
|---|---|---|
| `--line <N>` | Line index (0-based) of launch event | 0 |
| `--kernel <name>` | Find kernel by exact name | - |
| `--launch-id <N>` | Launch instance when using `--kernel` | 0 |
| `--out-dir <path>` | Output directory | `repro_output/<kernel>/` |
| `--template <name\|path>` | Template name or path to custom template | |
| `--kernel-import <mode>` | Import mode: `default`, `copy`, `override-ttir` | `default` |
Examples:
# Basic usage
tritonparseoss reproduce ./trace.ndjson.gz --line 1 --out-dir ./repro
# Find by kernel name
tritonparseoss reproduce ./trace.ndjson.gz --kernel add_kernel --out-dir ./repro
# Use tritonbench template
tritonparseoss reproduce ./trace.ndjson.gz --line 1 --template tritonbench
# Embed kernel source
tritonparseoss reproduce ./trace.ndjson.gz --line 1 --kernel-import copy
# Multiple options
tritonparseoss reproduce ./trace.ndjson.gz \
--kernel matmul_kernel \
--launch-id 2 \
--template /path/to/custom_template.py \
--kernel-import copy \
    --out-dir ./my_repro

Query kernel information from traces (useful before reproducing):
# List all kernels
tritonparseoss info ./trace.ndjson.gz
# Query specific kernel
tritonparseoss info ./trace.ndjson.gz --kernel matmul_kernel
# Show argument details
tritonparseoss info ./trace.ndjson.gz --kernel matmul_kernel --args-list

Extract a problematic kernel and debug in isolation:
# 1. Generate reproducer
tritonparseoss reproduce trace.ndjson.gz --line 42 --out-dir bug_repro
# 2. Run reproducer
cd bug_repro/<kernel_name>
python repro_*.py
# 3. Modify and debug
# Edit the generated script to add debugging

Create benchmarkable kernel scripts:
# Generate tritonbench-compatible reproducer
tritonparseoss reproduce trace.ndjson.gz --line 1 --template tritonbench --out-dir bench
# Run with tritonbench
python -m tritonbench --op bench/<kernel_name>/repro_*.py --mode latency

Or add manual timing:
# In generated script, add:
import time
start = time.perf_counter()
for _ in range(100):
    # {{KERNEL_INVOCATION_PLACEHOLDER}}  (the generated kernel launch goes here)
torch.cuda.synchronize()  # wait for all queued launches before stopping the clock
elapsed = time.perf_counter() - start
print(f"Average time: {elapsed / 100 * 1000:.3f} ms")

Compare behavior across versions:
# Generate reproducers for two traces
tritonparseoss reproduce trace_v1.ndjson.gz --line 1 --out-dir v1
tritonparseoss reproduce trace_v2.ndjson.gz --line 1 --out-dir v2
# Compare outputs
python v1/<kernel>/repro_*.py > v1_output.txt
python v2/<kernel>/repro_*.py > v2_output.txt
diff v1_output.txt v2_output.txt

Combine reproducer workflow with File Diff for comprehensive analysis:
- Generate traces from both versions
- Use File Diff View to identify differing kernels
- Generate reproducers for specific kernels
- Debug/benchmark in isolation
- Re-trace after fixes and compare again
# After identifying kernel differences in File Diff View
result = reproduce(
    input_path="./trace_v1.ndjson.gz",
    kernel_name="problematic_kernel",
    out_dir="./debug_v1",
)
# Modify and test the reproducer
# ...
# After fix, re-trace and compare in File Diff View

Create portable test cases:
# Use COPY mode to embed kernel source
tritonparseoss reproduce trace.ndjson.gz \
--line 1 \
--kernel-import copy \
--out-dir shareable_repro
# The generated script is self-contained
# Share the entire shareable_repro/<kernel>/ directory

Q: "Event at index N is not a launch event"
The specified line_index points to a non-launch event (e.g., compilation or launch_diff).
Solution: Use --line 1 or higher for launch events. Use tritonparseoss info to list events.
Q: "Could not find compilation hash in launch event"
The trace may be incomplete or corrupted.
Solution: Re-generate the trace with proper initialization.
Q: "Optional dependency 'triton_kernels.tensor' is not installed"
The trace uses custom tensor types from triton_kernels.
Solution: pip install triton_kernels
Q: Reproducer runs but produces different results
Tensor data reconstruction may not perfectly match original.
Solution:
- Enable `enable_tensor_blob_storage=True` during tracing for exact data
- Or enable `enable_more_tensor_information=True` for better approximation
Q: "Could not resolve kernel file path"
The kernel source file path cannot be determined.
Solution: Use --kernel-import copy to embed the source directly.
Enable verbose logging:
export TRITONPARSE_DEBUG=1
tritonparseoss reproduce trace.ndjson.gz --line 1

Or in Python:
import logging
logging.getLogger("tritonparse").setLevel(logging.DEBUG)

- Usage Guide - Reproducer Section - Quick reference
- Python API Reference - Full API documentation
- Environment Variables - Tensor storage configuration
- FAQ - Reproducer - Common questions