Skip to content

cuda: refactored ssm_scan and use CUB #13291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 152 additions & 70 deletions ggml/src/ggml-cuda/ssm-scan.cu
Original file line number Diff line number Diff line change
@@ -1,87 +1,117 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB

#include "ssm-scan.cuh"

template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t d_inner, const int64_t L) {

constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B (sequences)
const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;

extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;

const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
const size_t L = L_template == 0 ? L_param : L_template;
const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = d_inner;
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = d_inner;

// can N not be 16? for example 32?
if (N == 16) {
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
float regA[N];
float regs0[N];

__shared__ float smemB[N];
__shared__ float smemC[N];

#ifdef USE_CUB
using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

union CubTempStorage {
typename BlockLoad::TempStorage load_temp;
typename BlockStore::TempStorage store_temp;
};
__shared__ CubTempStorage cub_temp_storage;

BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
for (size_t n = 0; n < N; ++n)
{
regA[n] = A_block[threadIdx.x * stride_A + n];
regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
Comment on lines +68 to +69
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The memory access pattern here is inefficient though I also wouldn't know how to improve it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the problem lie in that the loads aren't coalesced? Wouldn't using a coalesced loading pattern require the data to be in a different layout?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the problem is the uncoalesced I/O. If you could somehow re-write the kernel to make the loads coalesced or change the memory pattern the previous kernel puts out the performance would likely be better. (I did not try to analyze whether something like this is possible.)

}
#endif

__syncthreads();
#pragma unroll
for (size_t i = 0; i < L; i++)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L is not known at compile time in the L_template == 0 case here, which means the #pragma unroll causes a warning when this is compiled via llvm.
At least for llvm, you can just remove the pragma as the compiler unrolls this loop anyhow for the L_template != 0 case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried removing the #pragma unroll and compared the output from Nsight Compute after running a quick test to make sure again. It makes a difference for CUDA, even in the case where L isn't known at compile time for some reason. Without explicitly unrolling the loop, it uses 2 more registers per thread. I could suppress the warning like in softmax.cu where the same sort of thing is done.

Copy link
Collaborator

@JohannesGaessler JohannesGaessler Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experience the CUDA compiler is very conservative when it comes to unrolling loops so my preference would definitely be to keep the #pragma unroll and suppress the warning.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes a difference for CUDA, even in the case where L isn't known at compile time for some reason. Without explicitly unrolling the loop, it uses 2 more registers per thread.

Thats really strange and sounds like a mild compiler bug.
Anyhow, suppressing the warning is sufficant for me.

{
if (threadIdx.x < N)
{
smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
}
__syncthreads();

for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
if (dt_soft_plus <= 20.0f)
{
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
for (size_t n = 0; n < N; n++)
{
float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
sumf += state * smemC[n];
regs0[n] = state;
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
y_block[i * stride_y + threadIdx.x] = sumf;
}

#ifdef USE_CUB
BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
const int stride_s = stride_s0;
#pragma unroll
for (size_t n = 0; n < N; ++n)
{
s_block[threadIdx.x * stride_s + n] = regs0[n];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The memory access pattern here is also inefficient.

}
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

// assumes as many threads as d_state
template <int splitH, int d_state>
Expand Down Expand Up @@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
const int threads = 128;
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
Expand All @@ -229,18 +259,70 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ABORT("doesn't support d_state!=(128 or 256).");
}
} else {
const int threads = 128;
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
switch (n_tok)
{
case 1:
ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 2:
ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 3:
ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 4:
ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 5:
ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 6:
ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 7:
ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
case 8:
ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
default:
ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
break;
}
} else {
GGML_ABORT("doesn't support d_state!=16.");
}
Expand Down