#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cublasVersionCheck.h"
#include <algorithm>
+#include <unordered_map>

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
@@ -63,6 +64,16 @@ void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperatio
        mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
+
+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+    // Set pointer mode for FP4 GEMM
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));
+    }
+#endif
}

void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
@@ -71,6 +82,39 @@ void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
    check_cuda_error(
        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
+
+    // Set scaling modes for FP4 GEMM
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        // cuBLASLt expects UE4M3 scale factors for A/B (one per 16-element block) and FP32 scalars for C/D
+        cublasLtMatmulMatrixScale_t AScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
+        cublasLtMatmulMatrixScale_t BScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
+        cublasLtMatmulMatrixScale_t CScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+        cublasLtMatmulMatrixScale_t DScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+        cublasLtMatmulMatrixScale_t DOutScaleMode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
+
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &AScaleMode, sizeof(AScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &BScaleMode, sizeof(BScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_MODE, &CScaleMode, sizeof(CScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &DScaleMode, sizeof(DScaleMode)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE, &DOutScaleMode, sizeof(DOutScaleMode)));
+
+        // No C/D scaling is used, so those scale pointers are explicitly cleared
+        void const* c_scale_ptr = nullptr;
+        void const* d_scale_ptr = nullptr;
+        void const* d_out_scale_ptr = nullptr;
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale_ptr, sizeof(c_scale_ptr)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale_ptr, sizeof(d_scale_ptr)));
+        check_cuda_error(cublasLtMatmulDescSetAttribute(
+            mOperationDesc, CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER, &d_out_scale_ptr, sizeof(d_out_scale_ptr)));
+    }
}

void CublasMMWrapper::setBiasDescriptor(void* bias)
@@ -247,14 +291,27 @@ void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
}
#endif

+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+void CublasMMWrapper::setFP4GemmConfig(cudaDataType_t outputType)
+{
+    setGemmConfig(CUDA_R_4F_E2M1, CUDA_R_4F_E2M1, outputType, CUDA_R_32F);
+}
+#endif
+
void CublasMMWrapper::setGemmConfig(
    cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
{
    mAType = aType;
    mBType = bType;
    mCType = cType;
    bool isFp16ComputeType = computeType == CUDA_R_16F;
-    if (isFp16ComputeType)
+    if (mAType == CUDA_R_4F_E2M1)
+    {
+        // cuBLASLt NVFP4 GEMM requires an FP32 compute type and FP32 scale type
+        mComputeType = CUBLAS_COMPUTE_32F;
+        mScaleType = CUDA_R_32F;
+    }
+    else if (isFp16ComputeType)
    {
        mComputeType = CUBLAS_COMPUTE_16F;
        mScaleType = CUDA_R_16F;
@@ -481,6 +538,127 @@ std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasL
#endif
}

+#ifdef ENABLE_CUBLASLT_FP4_GEMM
+
+namespace
+{
+// Helper function: get or create a zero-valued beta scalar on the GPU for the current device.
+// Beta is always 0 for FP4 GEMM and is allocated once per device per thread.
+float const* getBetaDevicePointer()
+{
+    thread_local static std::unordered_map<int, float*> beta_per_device;
+
+    int current_device;
+    cudaGetDevice(&current_device);
+
+    auto it = beta_per_device.find(current_device);
+    if (it == beta_per_device.end())
+    {
+        // Allocate GPU memory for beta and initialize it to 0
+        float* d_beta;
+        cudaMalloc(&d_beta, sizeof(float));
+        cudaMemset(d_beta, 0, sizeof(float));
+        beta_per_device[current_device] = d_beta;
+        return d_beta;
+    }
+
+    return it->second;
+}
+} // namespace
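+
+// Note: createDescriptors() selects CUBLASLT_POINTER_MODE_DEVICE for FP4, so both the alpha passed
+// to BlockScaleGemm() and the cached beta above are read from device memory by cuBLASLt.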
+
+// BlockScaleGemm Version 1: default algorithm (uses the first valid heuristic)
+void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
+    void const* b_sf, float const* alpha)
+{
+    // Forward to the overloaded version with nullptr (use the default algorithm)
+    BlockScaleGemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, a_sf, b_sf, alpha, nullptr);
+}
+
+// BlockScaleGemm Version 2: specified algorithm (unified implementation)
+void CublasMMWrapper::BlockScaleGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, void const* a_sf,
+    void const* b_sf, float const* alpha, cublasLtMatmulAlgo_t const* algo)
+{
+    // Verify input data types (currently supports FP4; may be extended to more formats in the future)
+    TLLM_CHECK_WITH_INFO(mAType == CUDA_R_4F_E2M1 && mBType == CUDA_R_4F_E2M1,
+        "BlockScaleGemm currently requires FP4 input types. "
+        "Future versions may support other quantized formats with block-wise scaling.");
+
+    // Validate input pointers
+    TLLM_CHECK_WITH_INFO(A != nullptr, "A pointer is null");
+    TLLM_CHECK_WITH_INFO(B != nullptr, "B pointer is null");
+    TLLM_CHECK_WITH_INFO(C != nullptr, "C pointer is null");
+    TLLM_CHECK_WITH_INFO(a_sf != nullptr, "a_sf (A scale factor) pointer is null");
+    TLLM_CHECK_WITH_INFO(b_sf != nullptr, "b_sf (B scale factor) pointer is null");
+    TLLM_CHECK_WITH_INFO(alpha != nullptr, "alpha pointer is null");
+
+    // Beta is always 0 for FP4 GEMM; get the per-device GPU pointer
+    float const* beta = getBetaDevicePointer();
+
+    // Create descriptors for the block-scaled GEMM
+    createDescriptors(transa, transb, m, n, k, lda, ldb, ldc, 0);
+
+    // Create the D descriptor for the output matrix
+    cublasLtMatrixLayout_t Ddesc = NULL;
+    check_cuda_error(cublasLtMatrixLayoutCreate(&Ddesc, mCType, m, n, ldc));
+
+    // Set block-wise scaling descriptors
+    setScaleDescriptors(const_cast<void*>(a_sf), const_cast<void*>(b_sf));
+
+    // Validate the cuBLASLt handle
+    TLLM_CHECK_WITH_INFO(mCublasLtHandle != nullptr, "cuBLASLt handle is null");
+
+    // Determine which algorithm to use
+    cublasLtMatmulAlgo_t const* selected_algo = algo;
+    cublasLtMatmulAlgo_t default_algo;
+
+    if (algo == nullptr)
+    {
+        // No algorithm specified, use heuristics (default behavior)
+        auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, Ddesc);
+
+        if (heuristics.empty())
+        {
+            if (Ddesc)
+                cublasLtMatrixLayoutDestroy(Ddesc);
+            destroyDescriptors();
+            throw std::runtime_error("No suitable cuBLASLt algorithm found for block-scaled GEMM");
+        }
+
+        // Use the first valid heuristic
+        auto const& heuristic = heuristics[0];
+        bool hasAlgo = heuristic.state == CUBLAS_STATUS_SUCCESS && heuristic.workspaceSize <= CUBLAS_WORKSPACE_SIZE;
+
+        if (hasAlgo)
+        {
+            default_algo = heuristic.algo;
+            selected_algo = &default_algo;
+        }
+        else
+        {
+            selected_algo = nullptr; // No valid algorithm, let cuBLASLt choose
+        }
+    }
+
+    int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
+
+    // Call cuBLASLt matmul with the selected or default algorithm
+    check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, mCDesc,
+        C, Ddesc, selected_algo, // nullptr or a specific algorithm
+        mCublasWorkspace, workspaceSize, mStream));
+
+    // Synchronize the stream
+    sync_check_cuda_error(mStream);
+
+    // Clean up descriptors
+    if (Ddesc)
+        cublasLtMatrixLayoutDestroy(Ddesc);
+    destroyDescriptors();
+}
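+
+// Illustrative call sequence (a sketch, not part of this change; d_* names are hypothetical
+// caller-prepared device buffers, and d_alpha must be a device pointer):
+//
+//   wrapper.setFP4GemmConfig(CUDA_R_16BF);   // FP4 A/B inputs, BF16 output, FP32 compute/scale
+//   wrapper.BlockScaleGemm(CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
+//       d_A, lda, d_B, ldb, d_C, ldc, d_a_sf, d_b_sf, d_alpha);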
+
+#endif
+
} // namespace common

} // namespace tensorrt_llm