Commit 9c5b464

[None][feat] Apply AutoTuner to fp8_block_scale_deep_gemm to trigger JIT ahead of time. (#7113)
Because deep_gemm.fp8_gemm_nt triggers many JIT compilations during the inference phase, these shapes need to be swept ahead of time. The AutoTuner framework is applied to achieve this, retaining the potential capability to tune the swap_ab flag. Signed-off-by: Yukun He <[email protected]>
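In practice, the new op is exercised the way the updated unit test below does it: one call inside an autotune() context lets the AutoTuner sweep the token-count buckets declared in fp8SwapABGemmRunner.tuning_config and JIT-compile the deep_gemm kernels before serving. A minimal sketch (not part of this commit; it reuses the unit-test helper per_block_cast_to_fp8_e8m0 and illustrative shapes, so it only runs from the tests/unittest directory on an SM100 GPU with tensorrt_llm installed):

import torch
from _torch.helpers import per_block_cast_to_fp8_e8m0  # unit-test helper
from tensorrt_llm._torch.autotuner import autotune

m, k, n = 4096, 7168, 4096  # illustrative shapes
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
b_fp8, b_sf = per_block_cast_to_fp8_e8m0(b)  # FP8 block-scaled weight

# One profiled call; the bucket sweep and the deep_gemm JIT happen inside the AutoTuner.
with autotune():
    out = torch.ops.trtllm.fp8_swap_ab_gemm(a, b_fp8, b_sf)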
1 parent c038fb3 commit 9c5b464

File tree

3 files changed: +106 -18 lines changed

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 90 additions & 0 deletions
@@ -4,6 +4,8 @@
 import torch
 
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
+import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
+from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ -890,6 +892,94 @@ def _(
     return input.new_empty((M, N), dtype=output_dtype)
 
 
+def fp8_swap_ab_gen_tuning_buckets(x: int):
+    buckets = tuple(range(8, 128, 8))
+    if x >= 128:
+        buckets += tuple(range(128, x, 128))
+    return buckets
+
+
+class fp8SwapABGemmRunner(TunableRunner):
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
+        tune_max_num_tokens=4096,
+    )
+
+    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
+        self.output_dtype = output_dtype
+        self.disable_ue8m0_cast = disable_ue8m0_cast
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+        # Encode swap_ab as False (0) and True (1). Currently only one tactic is added here.
+        return [0]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, weight, weight_scale = inputs
+        a, a_sf = fp8_utils.per_token_quant_and_transform(input)
+        output = torch.empty(
+            (input.size(0), weight.size(0)),
+            device=input.device,
+            dtype=self.output_dtype,
+        )
+        # TODO: add swap_ab=(tactic == 1) to determine the swap_ab value.
+        # Treat the default tactic=-1 as swap_ab=False.
+        deep_gemm.fp8_gemm_nt(
+            (a, a_sf),
+            (weight, weight_scale),
+            output,
+            disable_ue8m0_cast=self.disable_ue8m0_cast,
+        )
+        return output
+
+
+@torch.library.custom_op("trtllm::fp8_swap_ab_gemm", mutates_args=())
+def fp8_swap_ab_gemm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+    disable_ue8m0_cast: bool = False,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    tuner = AutoTuner.get()
+    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
+        output_dtype,
+        disable_ue8m0_cast,
+    )
+    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_swap_ab_gemm",
+        [fp8_swap_ab_gemm_runner],
+        fp8SwapABGemmRunner.tuning_config,
+        [input, weight, weight_scale],
+    )
+    return fp8_swap_ab_gemm_runner(
+        inputs=[input, weight, weight_scale],
+        tactic=best_tactic,
+    )
+
+
+@fp8_swap_ab_gemm.register_fake
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+    disable_ue8m0_cast: bool = False,
+    tune_max_num_tokens: int = 4096,
+) -> torch.Tensor:
+    return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
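For a sense of which shapes get swept, the bucket generator added above can be evaluated by hand; with the default tune_max_num_tokens this gives (not part of the diff, just the function's output):

fp8_swap_ab_gen_tuning_buckets(4096)
# -> (8, 16, 24, ..., 112, 120, 128, 256, 384, ..., 3840, 3968)
# i.e. every multiple of 8 below 128, then every multiple of 128 below 4096.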

tensorrt_llm/_torch/modules/linear.py

Lines changed: 6 additions & 10 deletions
@@ -12,7 +12,6 @@
 from torch.nn.parameter import Parameter
 
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
-import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm._torch.peft.lora.layer import LoraLayer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy)
@@ -591,15 +590,12 @@ def apply(self, module: Linear, input: torch.Tensor,
                     act_input_fp8, module.weight, act_input_sf,
                     module.weight_scale)
             else:
-                from tensorrt_llm import deep_gemm
-                a, a_sf = fp8_utils.per_token_quant_and_transform(input)
-                output = torch.empty((input.shape[0], module.weight.shape[0]),
-                                     device=input.device,
-                                     dtype=torch.bfloat16)
-                deep_gemm.fp8_gemm_nt((a, a_sf),
-                                      (module.weight, module.weight_scale),
-                                      output,
-                                      disable_ue8m0_cast=True)
+                output = torch.ops.trtllm.fp8_swap_ab_gemm(
+                    input,
+                    module.weight,
+                    module.weight_scale,
+                    disable_ue8m0_cast=True,
+                )
         else:
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)

tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py

Lines changed: 10 additions & 8 deletions
@@ -19,10 +19,11 @@
 import pytest
 import torch
 from _torch.helpers import (calc_diff, per_block_cast_to_fp8,
-                            per_block_cast_to_fp8_e8m0,
-                            per_token_cast_to_fp8_e8m0)
+                            per_block_cast_to_fp8_e8m0)
 from utils.util import getSMVersion
 
+from tensorrt_llm._torch.autotuner import autotune
+
 
 @pytest.mark.skipif(
     getSMVersion() != 100,
@@ -46,16 +47,17 @@ def test_fp8_block_scale_deep_gemm(dtype, m, k, n):
     a = torch.randn((m, k), device='cuda', dtype=dtype)
     b = torch.randn((n, k), device='cuda', dtype=dtype)
 
-    act_a_fp8, act_a_sf = per_token_cast_to_fp8_e8m0(a)
     act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b)
 
     output_expected = a @ b.t()
-    from tensorrt_llm import deep_gemm
-    output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]),
-                         device=act_a_fp8.device,
-                         dtype=torch.bfloat16)
 
-    deep_gemm.fp8_gemm_nt((act_a_fp8, act_a_sf), (act_b_fp8, act_b_sf), output)
+    with autotune():
+        output = torch.ops.trtllm.fp8_swap_ab_gemm(
+            a,
+            act_b_fp8,
+            act_b_sf,
+        )
+
     diff = calc_diff(output, output_expected)
     assert diff < 1e-2
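Assuming the usual unittest layout (the _torch.helpers import above implies running from the tests/unittest directory), the updated case can be exercised on an SM100 machine with something like:

pytest _torch/thop/test_fp8_block_scale_gemm.py -k fp8_block_scale_deep_gemm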
