# 07. Environment Variables Reference

This page provides a complete reference for all environment variables supported by TritonParse.

💡 **Looking for the Python API?** See the Python API Reference for complete documentation of `init()`, `unified_parse()`, `TritonParseManager`, and other Python interfaces.

📋 **Scope:** All environment variables documented here affect the `tritonparse.structured_logging` module, which handles trace generation during Triton kernel compilation and execution.
In OSS (open-source) environments, TritonParse is not automatically enabled. Even when using environment variables, you must call `init()` in your Python code to activate tracing:

```python
import tritonparse.structured_logging

# Just call init() - environment variables will be automatically applied
tritonparse.structured_logging.init()
```

How it works:

- Environment variables are read when the module is imported
- Calling `init()` without arguments activates tracing using the environment variable values
- If you pass arguments to `init()`, they take precedence over environment variables
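The precedence rule can be sketched in plain Python. The `resolve_trace_folder` helper below is hypothetical (not part of TritonParse); it only illustrates how an explicit `init()` argument wins over the environment variable:

```python
import os

def resolve_trace_folder(arg=None):
    # Hypothetical helper illustrating the precedence rule:
    # an explicit argument beats TRITON_TRACE from the environment.
    return arg if arg is not None else os.environ.get("TRITON_TRACE")

os.environ["TRITON_TRACE"] = "./logs/"
print(resolve_trace_folder())             # env var used: ./logs/
print(resolve_trace_folder("./custom/"))  # explicit argument wins: ./custom/
```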
Complete example with environment variables:

```bash
# Set environment variables
export TRITON_TRACE="./logs/"
export TRITON_TRACE_LAUNCH="1"
export TRITONPARSE_MORE_TENSOR_INFORMATION="1"
```

```python
import tritonparse.structured_logging

# Activate tracing - environment variables are automatically used
tritonparse.structured_logging.init()

# Your kernel code here...
```

## Quick Reference

| Variable | Purpose | Default | Category |
|---|---|---|---|
| `TRITON_TRACE` | Trace output directory | `/logs/` | Trace Generation |
| `TRITON_TRACE_LAUNCH` | Enable launch tracing | Off | Trace Generation |
| `TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK` | Enable inductor hook | Off | Trace Generation |
| `TRITONPARSE_MORE_TENSOR_INFORMATION` | Collect tensor stats | Off | Tensor Info |
| `TRITONPARSE_SAVE_TENSOR_BLOBS` | Save tensor blobs | Off | Tensor Info |
| `TRITONPARSE_TENSOR_SIZE_LIMIT` | Max tensor blob size | 10GB | Tensor Info |
| `TRITONPARSE_TENSOR_STORAGE_QUOTA` | Total storage quota | 100GB | Tensor Info |
| `TRITONPARSE_DEBUG` | Debug logging | Off | Debug |
| `TRITONPARSE_KERNEL_ALLOWLIST` | Kernel filter patterns | All | Debug |
| `TRITON_TRACE_GZIP` | Gzip compression | Off | Performance |
| `TRITONPARSE_DUMP_SASS` | SASS dump | Off | Performance |
| `TRITON_FULL_PYTHON_SOURCE` | Full source extraction | Off | Source |
| `TRITON_MAX_SOURCE_SIZE` | Max source file size | 10MB | Source |
| `TEST_KEEP_OUTPUT` | Keep test outputs | Off | Testing |
| `TORCHINDUCTOR_FX_GRAPH_CACHE` | PyTorch FX cache | On | Testing |
## Trace Generation

These variables control how TritonParse captures compilation and launch events.

### TRITON_TRACE

Description: Directory to store raw trace files.

| Property | Value |
|---|---|
| Values | Any valid directory path |
| Default | `/logs/` (if it exists and is writable), otherwise disabled |
| Related | Also set via `tritonparse.structured_logging.init(trace_folder=...)` |
Example:

```bash
export TRITON_TRACE="./my_logs/"
python your_script.py
```

### TRITON_TRACE_LAUNCH

Description: Enable kernel launch event tracing. When enabled, captures runtime launch parameters including grid dimensions, tensor arguments, and execution metadata.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |
| Related | Also set via `tritonparse.structured_logging.init(enable_trace_launch=True)` |

Example:

```bash
export TRITON_TRACE_LAUNCH="1"
```

💡 **Tip:** Launch tracing is essential for reproducer generation and launch diff analysis.
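Several flags on this page accept `"1"`, `"true"`, or `"True"`. A minimal sketch of that truthy convention (a plausible reading of the tables above, not TritonParse's exact parsing code):

```python
import os

def env_flag(name: str) -> bool:
    # Assumed convention from the documentation: only the exact
    # strings "1", "true", and "True" enable a flag.
    return os.environ.get(name, "") in ("1", "true", "True")

os.environ["TRITON_TRACE_LAUNCH"] = "true"
print(env_flag("TRITON_TRACE_LAUNCH"))  # True

os.environ["TRITON_TRACE_LAUNCH"] = "0"
print(env_flag("TRITON_TRACE_LAUNCH"))  # False
```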
### TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK

Description: Required for tracing kernel launches from TorchInductor (`torch.compile`). This hook allows TritonParse to intercept and log launch metadata for Inductor-compiled kernels.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled (auto-enabled if `TRITON_TRACE_LAUNCH=1`) |
| When to use | When using `torch.compile()` and you need launch tracing |

Example:

```bash
# Required for torch.compile kernels
export TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK="1"
export TRITON_TRACE_LAUNCH="1"
```

```python
# Or via Python
import os
os.environ["TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK"] = "1"

import tritonparse.structured_logging
tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)
```

⚠️ **Important:** For native Triton kernels (`@triton.jit`), this variable is not needed. It is only required for `torch.compile` / TorchInductor-generated kernels.
## Tensor Information

These variables control how TritonParse collects and stores tensor data for debugging and reproducer generation.

### TRITONPARSE_MORE_TENSOR_INFORMATION

Description: Collect detailed tensor statistics including min, max, mean, and standard deviation values. This information is useful for reproducer generation and debugging.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |
| Related | Also set via `tritonparse.structured_logging.init(enable_more_tensor_information=True)` |

Collected statistics:

- `min`: Minimum value in the tensor
- `max`: Maximum value in the tensor
- `mean`: Mean value of the tensor elements
- `std`: Standard deviation of the tensor elements
Example:
export TRITONPARSE_MORE_TENSOR_INFORMATION="1"# Or via Python
tritonparse.structured_logging.init(
"./logs/",
enable_trace_launch=True,
enable_more_tensor_information=True,
)π‘ Use case: When generating reproducers, these statistics allow for better tensor reconstruction that approximates the original data distribution.
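To illustrate what these four statistics capture, here is a plain-Python sketch computing them for a small list of values. TritonParse computes them on GPU tensors; the helper below is illustrative only, and whether population or sample standard deviation is used is an assumption:

```python
import statistics

def tensor_stats(values):
    # The same four statistics the documentation lists per tensor argument.
    return {
        "min": min(values),
        "max": max(values),
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population std; frameworks may default to sample std
    }

stats = tensor_stats([1.0, 2.0, 3.0, 4.0])
print(stats["min"], stats["max"], stats["mean"])  # 1.0 4.0 2.5
```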
### TRITONPARSE_SAVE_TENSOR_BLOBS

Description: Save actual tensor data as blob files. Enables highest-fidelity reproducer generation by preserving the exact tensor values.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |
| Storage location | `<trace_folder>/saved_tensors/` |
| Related | Also set via `tritonparse.structured_logging.init(enable_tensor_blob_storage=True)` |

Example:

```bash
export TRITONPARSE_SAVE_TENSOR_BLOBS="1"
```

⚠️ **Warning:** Can consume significant disk space. Use `TRITONPARSE_TENSOR_STORAGE_QUOTA` to limit storage.
### TRITONPARSE_TENSOR_SIZE_LIMIT

Description: Maximum size for individual tensor blobs. Tensors larger than this limit are skipped during blob storage.

| Property | Value |
|---|---|
| Values | Integer (bytes) |
| Default | 10GB (`10 * 1024 * 1024 * 1024`) |
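The byte values behind the documented defaults are easy to compute in Python, using the same arithmetic as the `$(( ... ))` shell expressions on this page:

```python
# Binary gigabyte, matching the 1024-based shell arithmetic used here.
GiB = 1024 ** 3

print(10 * GiB)   # 10737418240  - default TRITONPARSE_TENSOR_SIZE_LIMIT
print(100 * GiB)  # 107374182400 - default TRITONPARSE_TENSOR_STORAGE_QUOTA
```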
Example:

```bash
# Limit to 1GB per tensor
export TRITONPARSE_TENSOR_SIZE_LIMIT=$((1 * 1024 * 1024 * 1024))
```

### TRITONPARSE_TENSOR_STORAGE_QUOTA

Description: Total storage quota for tensor blobs in a single run. Once exceeded, blob storage is disabled for the remainder of the run.
| Property | Value |
|---|---|
| Values | Integer (bytes) |
| Default | 100GB (`100 * 1024 * 1024 * 1024`) |
| Related | Also set via `tritonparse.structured_logging.init(tensor_storage_quota=...)` |
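The quota behavior described above amounts to a running total that latches off once exceeded. A minimal sketch (the class and field names below are hypothetical, not TritonParse internals):

```python
class BlobQuota:
    # Hypothetical sketch of the documented behavior: once the running
    # total would exceed the quota, blob storage stays disabled for the run.
    def __init__(self, quota_bytes):
        self.quota = quota_bytes
        self.used = 0
        self.enabled = True

    def try_save(self, blob_bytes):
        if not self.enabled:
            return False
        if self.used + blob_bytes > self.quota:
            self.enabled = False  # disabled for the remainder of the run
            return False
        self.used += blob_bytes
        return True

q = BlobQuota(quota_bytes=100)
print(q.try_save(60))  # True
print(q.try_save(50))  # False - would exceed the quota; storage now disabled
print(q.try_save(10))  # False - stays disabled even for small blobs
```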
Example:

```bash
# Limit total storage to 50GB
export TRITONPARSE_TENSOR_STORAGE_QUOTA=$((50 * 1024 * 1024 * 1024))
```

## Debugging

These variables help with debugging TritonParse itself and filtering trace output.
### TRITONPARSE_DEBUG

Description: Enable debug logging for TritonParse internal operations.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |

Example:

```bash
export TRITONPARSE_DEBUG="1"
python your_script.py
```

💡 **Use case:** Helpful when troubleshooting trace generation issues or understanding TritonParse behavior.
### TRITONPARSE_KERNEL_ALLOWLIST

Description: Filter which kernels to trace using fnmatch patterns. Only kernels matching at least one pattern will be traced.

| Property | Value |
|---|---|
| Values | Comma-separated fnmatch patterns |
| Default | All kernels traced |

Pattern syntax:

- `*` matches any characters
- `?` matches a single character
- `[seq]` matches any character in seq
Example:

```bash
# Only trace matmul and attention kernels
export TRITONPARSE_KERNEL_ALLOWLIST="matmul*,*attention*,flash_*"
```

💡 **Use case:** Reduces trace file size and processing time when debugging specific kernels.
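You can preview which kernel names an allowlist would match using Python's standard `fnmatch` module, which implements the pattern syntax documented above. The helper below is illustrative (whether TritonParse trims whitespace around patterns is an assumption):

```python
from fnmatch import fnmatch

def matches_allowlist(kernel_name, allowlist):
    # Comma-separated patterns; a kernel is traced if any pattern matches.
    patterns = [p.strip() for p in allowlist.split(",")]
    return any(fnmatch(kernel_name, p) for p in patterns)

allowlist = "matmul*,*attention*,flash_*"
print(matches_allowlist("matmul_kernel", allowlist))       # True  (matmul*)
print(matches_allowlist("fused_attention_fwd", allowlist)) # True  (*attention*)
print(matches_allowlist("softmax_kernel", allowlist))      # False
```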
## Source Extraction

These variables control how Python source code is extracted and stored in traces.

### TRITON_FULL_PYTHON_SOURCE

Description: Extract the entire Python source file instead of just the kernel function definition.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Function-only extraction |

Example:

```bash
export TRITON_FULL_PYTHON_SOURCE="1"
```

💡 **Use case:** Useful when you need full context, including imports and helper functions, for debugging.
### TRITON_MAX_SOURCE_SIZE

Description: Maximum file size for full Python source extraction. Files larger than this limit fall back to function-only extraction.

| Property | Value |
|---|---|
| Values | Integer (bytes) |
| Default | 10MB (`10 * 1024 * 1024`) |
Example:

```bash
# Allow up to 20MB source files
export TRITON_MAX_SOURCE_SIZE=$((20 * 1024 * 1024))
```

## Performance

These variables affect trace file size and compilation performance.
### TRITON_TRACE_GZIP

Description: Enable gzip compression for individual trace log entries. Each log entry is compressed as a separate gzip member, allowing for incremental writing and standard gzip decompression.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |
| File extension | Changes from `.ndjson` to `.bin.ndjson` |

Example:

```bash
export TRITON_TRACE_GZIP="1"
```

💡 **Note:** The output file is in gzip format but uses the `.bin.ndjson` extension for naming consistency. Standard gzip tools can decompress it.
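Concatenated gzip members decompress as a single stream with standard tools, which is what makes this format incrementally writable. A small sketch of the idea (the file name and JSON entries below are illustrative, not TritonParse's actual schema):

```python
import gzip

# Append two log entries as separate gzip members to one file,
# mirroring the member-per-entry pattern described above.
path = "trace_demo.bin.ndjson"
with open(path, "wb") as f:
    for entry in (b'{"event": "compilation"}\n', b'{"event": "launch"}\n'):
        f.write(gzip.compress(entry))

# Standard gzip decompression reads all members back as one stream.
with gzip.open(path, "rb") as f:
    lines = f.read().splitlines()
print(len(lines))  # 2
```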
### TRITONPARSE_DUMP_SASS

Description: Dump NVIDIA SASS (Shader Assembly) from CUBIN files using `nvdisasm`.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled |
| Requires | NVIDIA CUDA toolkit with `nvdisasm` |
| Related | Also set via `tritonparse.structured_logging.init(enable_sass_dump=True)` |

Example:

```bash
export TRITONPARSE_DUMP_SASS="1"
```

⚠️ **Warning:** Significantly slows down compilation. Only enable when needed for low-level analysis.
## Testing

These variables are primarily used during development and testing.

### TEST_KEEP_OUTPUT

Description: Preserve temporary output directories from tests instead of cleaning them up.

| Property | Value |
|---|---|
| Values | `"1"`, `"true"`, `"True"` to enable |
| Default | Disabled (cleanup on exit) |

Example:

```bash
export TEST_KEEP_OUTPUT="1"
python -m unittest tests.test_tritonparse -v
```

### TORCHINDUCTOR_FX_GRAPH_CACHE

Description: (PyTorch variable) Controls FX graph caching. Disable it to ensure fresh compilation on each run.
| Property | Value |
|---|---|
| Values | `"0"` to disable |
| Default | Enabled (caching on) |

Example:

```bash
# Disable cache to ensure fresh compilation
export TORCHINDUCTOR_FX_GRAPH_CACHE="0"
python your_test.py
```

💡 **Use case:** Essential during testing to ensure kernels are recompiled and traces are generated.
## Common Configurations

Basic tracing:

```bash
export TRITON_TRACE="./logs/"
python your_triton_script.py
```

torch.compile with launch tracing:

```bash
export TRITON_TRACE="./logs/"
export TRITON_TRACE_LAUNCH="1"
export TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK="1"
export TRITONPARSE_MORE_TENSOR_INFORMATION="1"
python your_torch_compile_script.py
```

Full tensor capture for reproducers:

```bash
export TRITON_TRACE="./logs/"
export TRITON_TRACE_LAUNCH="1"
export TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK="1"
export TRITONPARSE_MORE_TENSOR_INFORMATION="1"
export TRITONPARSE_SAVE_TENSOR_BLOBS="1"
export TRITONPARSE_TENSOR_STORAGE_QUOTA=$((10 * 1024 * 1024 * 1024))  # 10GB limit
python your_script.py
```

Debugging specific kernels:

```bash
export TRITONPARSE_DEBUG="1"
export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*"
export TORCHINDUCTOR_FX_GRAPH_CACHE="0"
python your_script.py
```

## See Also

- Usage Guide - Complete workflow examples
- Python API Reference - Python API documentation
- Installation - Setup instructions
- FAQ - Common questions