|
4 | 4 | import torch
5 | 5 |
6 | 6 | import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
| 7 | +import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
| 8 | +from tensorrt_llm import deep_gemm
7 | 9 | from tensorrt_llm._utils import get_sm_version
8 | 10 |
9 | 11 | from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
@@ -890,6 +892,94 @@ def _(
890 | 892 |     return input.new_empty((M, N), dtype=output_dtype)
891 | 893 |
892 | 894 |
| 895 | +def fp8_swap_ab_gen_tuning_buckets(x: int):
| 896 | +    buckets = tuple(range(8, 128, 8))
| 897 | +    if x >= 128:
| 898 | +        buckets += tuple(range(128, x, 128))
| 899 | +    return buckets
| 900 | +
| 901 | +
| 902 | +class fp8SwapABGemmRunner(TunableRunner):
| 903 | +    tuning_config = TuningConfig(
| 904 | +        dynamic_tensor_specs=(DynamicTensorSpec(
| 905 | +            0, 0, fp8_swap_ab_gen_tuning_buckets), ),
| 906 | +        tune_max_num_tokens=4096,
| 907 | +    )
| 908 | +
| 909 | +    def __init__(self, output_dtype: torch.dtype, disable_ue8m0_cast: bool):
| 910 | +        self.output_dtype = output_dtype
| 911 | +        self.disable_ue8m0_cast = disable_ue8m0_cast
| 912 | +
| 913 | +    def get_valid_tactics(
| 914 | +        self,
| 915 | +        inputs: List[torch.Tensor],
| 916 | +        profile: OptimizationProfile,
| 917 | +    ) -> List[int]:
| 918 | +        # Encode swap_ab as False (0) and True (1). Currently only one tactic is added here.
| 919 | +        return [0]
| 920 | +
| 921 | +    def forward(
| 922 | +        self,
| 923 | +        inputs: List[torch.Tensor],
| 924 | +        tactic: int = -1,
| 925 | +    ) -> torch.Tensor:
| 926 | +        input, weight, weight_scale = inputs
| 927 | +        a, a_sf = fp8_utils.per_token_quant_and_transform(input)
| 928 | +        output = torch.empty(
| 929 | +            (input.size(0), weight.size(0)),
| 930 | +            device=input.device,
| 931 | +            dtype=self.output_dtype,
| 932 | +        )
| 933 | +        # TODO: pass swap_ab=(tactic == 1) to determine the swap_ab value
| 934 | +        # Treat the default tactic=-1 as swap_ab=False
| 935 | +        deep_gemm.fp8_gemm_nt(
| 936 | +            (a, a_sf),
| 937 | +            (weight, weight_scale),
| 938 | +            output,
| 939 | +            disable_ue8m0_cast=self.disable_ue8m0_cast,
| 940 | +        )
| 941 | +        return output
| 942 | +
| 943 | +
| 944 | +@torch.library.custom_op("trtllm::fp8_swap_ab_gemm", mutates_args=())
| 945 | +def fp8_swap_ab_gemm(
| 946 | +    input: torch.Tensor,
| 947 | +    weight: torch.Tensor,
| 948 | +    weight_scale: torch.Tensor,
| 949 | +    output_dtype: torch.dtype = torch.bfloat16,
| 950 | +    disable_ue8m0_cast: bool = False,
| 951 | +    tune_max_num_tokens: int = 4096,
| 952 | +) -> torch.Tensor:
| 953 | +    tuner = AutoTuner.get()
| 954 | +    fp8_swap_ab_gemm_runner = fp8SwapABGemmRunner(
| 955 | +        output_dtype,
| 956 | +        disable_ue8m0_cast,
| 957 | +    )
| 958 | +    fp8SwapABGemmRunner.tuning_config.tune_max_num_tokens = tune_max_num_tokens
| 959 | +    _, best_tactic = tuner.choose_one(
| 960 | +        "trtllm::fp8_swap_ab_gemm",
| 961 | +        [fp8_swap_ab_gemm_runner],
| 962 | +        fp8SwapABGemmRunner.tuning_config,
| 963 | +        [input, weight, weight_scale],
| 964 | +    )
| 965 | +    return fp8_swap_ab_gemm_runner(
| 966 | +        inputs=[input, weight, weight_scale],
| 967 | +        tactic=best_tactic,
| 968 | +    )
| 969 | +
| 970 | +
| 971 | +@fp8_swap_ab_gemm.register_fake
| 972 | +def _(
| 973 | +    input: torch.Tensor,
| 974 | +    weight: torch.Tensor,
| 975 | +    weight_scale: torch.Tensor,
| 976 | +    output_dtype: torch.dtype = torch.bfloat16,
| 977 | +    disable_ue8m0_cast: bool = False,
| 978 | +    tune_max_num_tokens: int = 4096,
| 979 | +) -> torch.Tensor:
| 980 | +    return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)
| 981 | +
| 982 | +
893 | 983 | def get_event(event_idx: int):
894 | 984 |     from ..utils import get_model_extra_attrs
895 | 985 |     extra_attrs = get_model_extra_attrs()
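For reference, a minimal usage sketch of the new trtllm::fp8_swap_ab_gemm op follows. It is illustrative only: the problem sizes, the FP8 weight dtype (torch.float8_e4m3fn), and the 128x128 block layout assumed for weight_scale are not taken from the diff, and it assumes the module registering the op has already been imported.

# Hypothetical usage sketch (assumptions noted above); not part of the change.
import torch

M, N, K = 256, 4096, 7168  # assumed sizes: tokens M, output dim N, hidden dim K
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activations; quantized inside the op
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # assumed FP8 weight dtype
w_sf = torch.ones(N // 128, K // 128, dtype=torch.float32, device="cuda")  # assumed 128x128 block scales

# The op quantizes the activations per token, runs deep_gemm.fp8_gemm_nt,
# and lets the AutoTuner pick a tactic (only tactic 0, swap_ab=False, exists so far).
out = torch.ops.trtllm.fp8_swap_ab_gemm(
    x, w, w_sf,
    output_dtype=torch.bfloat16,
    disable_ue8m0_cast=False,
    tune_max_num_tokens=4096,
)
assert out.shape == (M, N)  # output is (num_tokens, weight.size(0)) in output_dtype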