
Commit f7ad6b1

jananisriram authored and facebook-github-bot committed
[Inductor][FP8] Validate exhaustive autotuning for FP8 Inductor templates (pytorch#161442)
Summary:
X-link: meta-pytorch/tritonbench#355
Pull Request resolved: pytorch#161442

Validate exhaustive autotuning for FP8 Inductor templates: scaled MM templates require `block_k >= 32`. Previously, exhaustive autotuning fell back to a limited set of autotuning configs, because the limitations of exhaustive autotuning on FP8 shapes had not been tested.

Test Plan:
```
CUDA_VISIBLE_DEVICES=0 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=DEFAULT buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --input-loader=/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/json_files/rowwise_ptma_0.json --output="/home/jananisriram/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0_bench.csv" --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/exhaustive_autotune_rowwise_persistent_tma/autotune/gpu0.log
```
This run autotunes over the maximum set of configs available, rather than the defaults, and skips configs that are not compatible with TMA.

Rollback Plan:

Differential Revision: D80958642
1 parent 7376111 commit f7ad6b1
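For context, below is a minimal sketch (not taken from this PR) of how one might exercise Inductor's scaled-MM templates under max-autotune. The `fp8_gemm` helper, the 1024-sized shapes, and the tensor-wise unit scales are illustrative assumptions; `max_autotune_gemm_search_space` is assumed to be the config-side equivalent of the `TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE` environment variable used in the test plan, and an FP8-capable GPU (e.g. H100) is assumed.

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: these Inductor config knobs mirror the TORCHINDUCTOR_MAX_AUTOTUNE_GEMM
# and TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE environment variables.
inductor_config.max_autotune_gemm = True
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"

def fp8_gemm(x_fp8, w_fp8_t, scale_x, scale_w):
    # torch._scaled_mm expects the second operand in column-major layout,
    # hence the transposed weight passed in by the caller.
    return torch._scaled_mm(
        x_fp8, w_fp8_t, scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16
    )

compiled = torch.compile(fp8_gemm, mode="max-autotune")

M, K, N = 1024, 1024, 1024  # illustrative problem size
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
scale_x = torch.tensor(1.0, device="cuda")  # tensor-wise scales for simplicity
scale_w = torch.tensor(1.0, device="cuda")

# Inductor autotunes its scaled-MM template configs on the first call.
out = compiled(x, w.t(), scale_x, scale_w)
```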

File tree: 1 file changed (+6, -13 lines)

torch/_inductor/template_heuristics.py

Lines changed: 6 additions & 13 deletions
```diff
@@ -1522,9 +1522,9 @@ class ScaledTMAConfigMixin(ScaledMMConfigMixin):
 
     def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
         """
-        TMA specific filtering, as num_warps=2 not safe for TMA
+        TMA specific filtering
         """
-        configs = [c for c in configs if c.num_warps != 2]
+        configs = [c for c in configs if c.num_warps != 2 and getattr(c, "block_k", None) is not None and c.block_k >= 32]
         return super()._filter_configs(configs)
 
     def get_template_configs(
@@ -1571,7 +1571,6 @@ def __init__(self) -> None:
         self.mm_configs = self.extra_mm_configs
         self.exhaustive_configs = self.extra_mm_configs
 
-
 # TODO(coconutruben): replace with template.name once templates are importable
 @register_template_heuristic(
     "mm_persistent_tma", "cuda", register=torch.version.hip is None
@@ -1596,11 +1595,10 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_mm_configs
         self.mm_configs = self.scaled_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_mm_configs
+
+    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
+        configs = [c for c in configs if getattr(c, "block_k", None) is not None and c.block_k >= 32]
+        return super()._filter_configs(configs)
 
 
 # TODO(coconutruben): replace with template.name once templates are importable
@@ -1614,11 +1612,6 @@ def __init__(self) -> None:
         super().__init__()
         # Override mm_configs to use scaled_persistent_mm_configs for TMA
        self.mm_configs = self.scaled_persistent_mm_configs
-        # NOTE: overriding exhaustive configs here to be the same as mm_configs
-        # as we haven't validated exhaustive support here yet
-        # TODO(coconutruben): remove this once we have validated exhaustive support
-        # for scaled_mm
-        self.exhaustive_configs = self.scaled_persistent_mm_configs
 
 
 # TODO(coconutruben): replace with template.name once templates are importable
```
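To make the new filtering concrete, here is a standalone sketch (not the PR code itself) of the `block_k >= 32` constraint applied to a list of candidate configs. `GemmConfig` and `filter_scaled_mm_configs` are hypothetical stand-ins for Inductor's `BaseConfig` and the `_filter_configs` methods added in the diff.

```python
from dataclasses import dataclass

@dataclass
class GemmConfig:
    # Hypothetical stand-in for Inductor's BaseConfig tiling parameters.
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int

def filter_scaled_mm_configs(configs):
    # Keep only configs whose block_k satisfies the scaled-MM requirement block_k >= 32.
    return [
        c
        for c in configs
        if getattr(c, "block_k", None) is not None and c.block_k >= 32
    ]

candidates = [
    GemmConfig(64, 64, 16, 2, 4),    # dropped: block_k < 32
    GemmConfig(128, 128, 32, 3, 8),  # kept
    GemmConfig(64, 128, 64, 4, 4),   # kept
]
print(filter_scaled_mm_configs(candidates))
```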

0 commit comments