Conversation

@mnehete32 (Contributor)

Follow-up of PR #15635.

Convolution Performance Results (Old)

FP32 (float32) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | GFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 3 | 368871.67 | 137.42 GFLOP | 372.55 |
| [19,19,8,16] | [4,4,8,128] | 2992 | 399.90 | 133.69 MFLOP | 334.32 |
| [19,19,8,16] | [4,4,8,130] | 2948 | 407.43 | 135.78 MFLOP | 333.27 |
| [19,19,4,16] | [2,2,4,4] | 131072 | 8.05 | 642.82 kFLOP | 79.89 |
| [224,224,3,1] | [3,3,3,8] | 14358 | 103.77 | 20.90 MFLOP | 201.38 |
| [224,224,1,1] | [2,2,1,8] | 24576 | 55.27 | 2.78 MFLOP | 50.39 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 437.02 | 22.28 MFLOP | 50.98 |
| [58,58,32,1] | [3,3,32,64] | 3468 | 368.43 | 115.40 MFLOP | 313.23 |
| [58,58,32,8] | [3,3,32,64] | 436 | 2888.17 | 923.24 MFLOP | 319.66 |
| [16,16,128,8] | [3,3,128,512] | 220 | 5489.08 | 1.85 GFLOP | 336.83 |

FP16 (float16) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | GFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 3 | 403320.33 | 137.42 GFLOP | 340.73 |
| [19,19,8,16] | [4,4,8,128] | 2244 | 448.02 | 133.69 MFLOP | 298.41 |
| [19,19,8,16] | [4,4,8,130] | 2211 | 455.74 | 135.78 MFLOP | 297.94 |
| [19,19,4,16] | [2,2,4,4] | 122880 | 8.63 | 642.82 kFLOP | 74.47 |
| [224,224,3,1] | [3,3,3,8] | 9572 | 116.88 | 20.90 MFLOP | 178.78 |
| [224,224,1,1] | [2,2,1,8] | 24576 | 60.58 | 2.78 MFLOP | 45.97 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 474.92 | 22.28 MFLOP | 46.91 |
| [58,58,32,1] | [3,3,32,64] | 2601 | 411.38 | 115.40 MFLOP | 280.53 |
| [58,58,32,8] | [3,3,32,64] | 327 | 3302.93 | 923.24 MFLOP | 279.52 |
| [16,16,128,8] | [3,3,128,512] | 165 | 6181.65 | 1.85 GFLOP | 299.09 |

Convolution Performance Results (New)

FP32 (float32) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | TFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 12 | 87193.67 | 137.42 GFLOP | 1.58 |
| [19,19,8,16] | [4,4,8,128] | 10472 | 97.68 | 133.69 MFLOP | 1.37 |
| [19,19,8,16] | [4,4,8,130] | 5896 | 180.83 | 135.78 MFLOP | 0.75 |
| [19,19,4,16] | [2,2,4,4] | 40960 | 24.69 | 642.82 kFLOP | 0.026 |
| [224,224,3,1] | [3,3,3,8] | 4786 | 254.30 | 20.90 MFLOP | 0.082 |
| [224,224,1,1] | [2,2,1,8] | 8192 | 130.07 | 2.78 MFLOP | 0.021 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 1038.51 | 22.28 MFLOP | 0.021 |
| [58,58,32,1] | [3,3,32,64] | 5202 | 199.36 | 115.40 MFLOP | 0.579 |
| [58,58,32,8] | [3,3,32,64] | 872 | 1248.11 | 923.24 MFLOP | 0.740 |
| [16,16,128,8] | [3,3,128,512] | 660 | 1631.43 | 1.85 GFLOP | 1.13 |

FP16 (float16) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | TFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 26 | 38639.81 | 137.42 GFLOP | 3.56 |
| [19,19,8,16] | [4,4,8,128] | 19448 | 51.80 | 133.69 MFLOP | 2.58 |
| [19,19,8,16] | [4,4,8,130] | 11792 | 88.61 | 135.78 MFLOP | 1.53 |
| [19,19,4,16] | [2,2,4,4] | 81920 | 13.16 | 642.82 kFLOP | 0.049 |
| [224,224,3,1] | [3,3,3,8] | 9572 | 126.75 | 20.90 MFLOP | 0.165 |
| [224,224,1,1] | [2,2,1,8] | 16384 | 67.38 | 2.78 MFLOP | 0.041 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 529.18 | 22.28 MFLOP | 0.042 |
| [58,58,32,1] | [3,3,32,64] | 11271 | 92.85 | 115.40 MFLOP | 1.24 |
| [58,58,32,8] | [3,3,32,64] | 1744 | 588.97 | 923.24 MFLOP | 1.57 |
| [16,16,128,8] | [3,3,128,512] | 1320 | 781.39 | 1.85 GFLOP | 2.37 |

Convolution Performance Comparison (Old vs New)

FP32 (float32)

| Input Shape | Kernel Shape | Old GFLOPS | New GFLOPS | Improvement |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 372.55 | 1580 | +4.2× |
| [19,19,8,16] | [4,4,8,128] | 334.32 | 1370 | +4.1× |
| [19,19,8,16] | [4,4,8,130] | 333.27 | 751 | +2.3× |
| [19,19,4,16] | [2,2,4,4] | 79.89 | 26.04 | -67% |
| [224,224,3,1] | [3,3,3,8] | 201.38 | 82.17 | -59% |
| [224,224,1,1] | [2,2,1,8] | 50.39 | 21.41 | -57% |
| [224,224,1,8] | [2,2,1,8] | 50.98 | 21.45 | -58% |
| [58,58,32,1] | [3,3,32,64] | 313.23 | 578.89 | +1.85× |
| [58,58,32,8] | [3,3,32,64] | 319.66 | 739.71 | +2.3× |
| [16,16,128,8] | [3,3,128,512] | 336.83 | 1130 | +3.36× |

FP16 (float16)

| Input Shape | Kernel Shape | Old GFLOPS | New GFLOPS | Improvement |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 | 3560 | +10.5× |
| [19,19,8,16] | [4,4,8,128] | 298.41 | 2580 | +8.6× |
| [19,19,8,16] | [4,4,8,130] | 297.94 | 1530 | +5.1× |
| [19,19,4,16] | [2,2,4,4] | 74.47 | 48.84 | -34% |
| [224,224,3,1] | [3,3,3,8] | 178.78 | 164.86 | -8% |
| [224,224,1,1] | [2,2,1,8] | 45.97 | 41.33 | -10% |
| [224,224,1,8] | [2,2,1,8] | 46.91 | 42.10 | -10% |
| [58,58,32,1] | [3,3,32,64] | 280.53 | 1240 | +4.4× |
| [58,58,32,8] | [3,3,32,64] | 279.52 | 1570 | +5.6× |
| [16,16,128,8] | [3,3,128,512] | 299.09 | 2370 | +7.9× |

Summary:

  • Large convolutions now see a 3–10× improvement in GFLOPS.
  • Small convolutions may see lower GFLOPS because they are memory-bound (shared memory), not compute-bound.
  • FP16 gains are more significant than FP32 gains on large kernels.
  • Used ggml_cuda_cast<T> to make sure the build doesn't break.

github-actions bot added the labels Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Sep 5, 2025
* removed flash-attention definition
@JohannesGaessler (Collaborator) left a comment

If you're going to make your own primitives anyways, take a look at mma.cuh. The WMMA interface NVIDIA provides for the "high-level" CUDA code is quite frankly terrible, so I exposed the tensor core PTX instructions (assembly equivalent). The practical upside is that you get a defined memory layout (important for mul_mat_q and FlashAttention but I think not here) and that you can use smaller matrix tiles (minimum is 16x8). The downside is that Volta and AMD are still lacking an implementation.
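For context, a minimal sketch of the nvcuda::wmma interface being discussed (a generic 16×16×16 fragment round trip, not code from this PR); it also shows why the unspecified accumulator-to-thread mapping forces a trip through memory before the result layout is known:

```cuda
// Minimal WMMA illustration: one warp computes a single 16x16 tile, fp16 in,
// fp32 accumulate. Compile for sm_70 or newer (e.g. -arch=sm_75).
#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

__global__ void wmma_16x16x16(const half * A, const half * B, float * D) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, 16);   // leading dimension = 16
    wmma::load_matrix_sync(b, B, 16);
    wmma::mma_sync(c, a, b, c);

    // Which c.x[i] belongs to which output element is not specified by the WMMA
    // API, so the result has to be stored (here to global memory, in a real
    // kernel usually to shared memory) before it can be rearranged reliably.
    wmma::store_matrix_sync(D, c, 16, wmma::mem_row_major);
}
```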

@Green-Sky (Collaborator) commented Sep 5, 2025

It is becoming increasingly hard to test these kinds of changes with sd.cpp. @ggerganov please sync ggml, it has been 2.5 weeks of rapid convolution development. :)

@JohannesGaessler (Collaborator)

I only monitor the llama.cpp and ggml repositories but when it comes to convolution kernels such as this it would also be fine for me if you open PRs in sd.cpp and tag me.

@ggerganov (Member)

ggml repo is up-to-date now

@Green-Sky (Collaborator) commented Sep 5, 2025

@mnehete32 please run tests, the output seems to be broken.

(attached screenshot: output_2, showing the broken output)

@mnehete32 (Contributor, Author)

@Green-Sky Checking

@mnehete32 (Contributor, Author)

I think the kernel was not launching its full set of threads, since the launcher launches one warp per WMMA_M × WMMA_N tile. I will look into launching fewer threads per block. Also, with mma.cuh it looks like I don't need shared memory to store the results; I haven't checked that completely yet and am still looking into it.
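For reference, a sketch of the kind of launch-geometry arithmetic being described here; WMMA_M, WMMA_N, and the warps-per-block value are placeholders, not the PR's actual configuration:

```cuda
// Hypothetical grid sizing: one warp per WMMA_M x WMMA_N output tile, rounded
// up over output channels and spatial positions so every tile gets a warp.
#include <cuda_runtime.h>
#include <cstdint>

constexpr int WMMA_M          = 16;  // output-channel tile (assumption)
constexpr int WMMA_N          = 16;  // spatial tile (assumption)
constexpr int WARPS_PER_BLOCK = 8;   // assumption

static dim3 conv2d_grid(int64_t OC, int64_t OH, int64_t OW, int64_t N) {
    const int64_t tiles_oc = (OC      + WMMA_M - 1) / WMMA_M;
    const int64_t tiles_sp = (N*OH*OW + WMMA_N - 1) / WMMA_N;
    const int64_t warps    = tiles_oc * tiles_sp;
    const int64_t blocks   = (warps + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
    return dim3((unsigned) blocks);
}
```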

@JohannesGaessler (Collaborator)

> it looks like I don't need shared memory to store results

FYI, for tensor cores you theoretically don't need shared memory at all. Each thread in a warp holds fractions of the input and output tiles in its registers. You only need shared memory to organize the data in such a way that the global memory accesses are coalesced (see mmf.cu) or in the case of WMMA to work around the memory layout being undefined.

@mnehete32 (Contributor, Author)

I thought that because the output-to-thread mapping is unknown (it changes based on architecture), I first needed to stage the output in shared memory before storing it.

@JohannesGaessler (Collaborator)

If you read the PTX documentation you'll find that all tensor core instructions have a well-defined memory layout. It's only when you try to cover all tensor core instructions with a single simple interface that you run into problems. Volta has 8x8 tensor cores. Turing, Ampere, and Ada Lovelace have 16x8 tensor cores (used by mma.cuh). Hopper has special asynchronous tensor cores, and Blackwell has yet other instructions. I don't think either of the latter two would fit WMMA, but the 16x8 instructions are still available.
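As an illustration of those 16x8 instructions, a sketch of the raw PTX mma.sync call with its documented per-thread register layout; this particular m16n8k16 variant needs Ampere or newer (Turing provides m16n8k8 instead), and ggml's mma.cuh wraps such calls in its own tile types:

```cuda
// One warp-wide m16n8k16 multiply-accumulate: D = A*B + C, fp16 inputs,
// fp32 accumulators. Per thread, A is 4 b32 registers of packed halves,
// B is 2, C and D are 4 floats -- the mapping is fixed by the PTX ISA.
#include <cuda_fp16.h>

__device__ void mma_m16n8k16_f16_f32(float d[4], const unsigned a[4],
                                     const unsigned b[2], const float c[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
}
```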

@mnehete32 (Contributor, Author)

> @mnehete32 please run tests, the output seems to be broken.
>
> (attached screenshot: output_2)

Warps were not covering the whole block; I have fixed the issue and tested it in sd.cpp. @Green-Sky

@mnehete32 (Contributor, Author)

Convolution Performance

FP32 (float32)

| Input Shape | Kernel Shape | Time (µs/run) | FLOPs/run | Perf |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 117,662.67 | 137.42 GF | 1.17 TFLOPS |
| [19,19,8,16] | [4,4,8,128] | 138.88 | 133.69 MF | 962.63 GFLOPS |
| [19,19,8,16] | [4,4,8,130] | 182.86 | 135.78 MF | 742.55 GFLOPS |
| [19,19,4,16] | [2,2,4,4] | 15.21 | 642.82 kF | 42.27 GFLOPS |
| [224,224,3,1] | [3,3,3,8] | 157.88 | 20.90 MF | 132.35 GFLOPS |
| [224,224,1,1] | [2,2,1,8] | 80.05 | 2.78 MF | 34.79 GFLOPS |
| [224,224,1,8] | [2,2,1,8] | 627.91 | 22.28 MF | 35.48 GFLOPS |
| [58,58,32,1] | [3,3,32,64] | 137.30 | 115.40 MF | 840.54 GFLOPS |
| [58,58,32,8] | [3,3,32,64] | 875.12 | 923.24 MF | 1.05 TFLOPS |
| [16,16,128,8] | [3,3,128,512] | 1752.53 | 1.85 GF | 1.05 TFLOPS |

FP16 (float16)

| Input Shape | Kernel Shape | Time (µs/run) | FLOPs/run | Perf |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 141,360.62 | 137.42 GF | 972.14 GFLOPS |
| [19,19,8,16] | [4,4,8,128] | 190.27 | 133.69 MF | 702.63 GFLOPS |
| [19,19,8,16] | [4,4,8,130] | 213.62 | 135.78 MF | 635.63 GFLOPS |
| [19,19,4,16] | [2,2,4,4] | 15.85 | 642.82 kF | 40.55 GFLOPS |
| [224,224,3,1] | [3,3,3,8] | 272.45 | 20.90 MF | 76.70 GFLOPS |
| [224,224,1,1] | [2,2,1,8] | 174.22 | 2.78 MF | 15.98 GFLOPS |
| [224,224,1,8] | [2,2,1,8] | 1387.19 | 22.28 MF | 16.06 GFLOPS |
| [58,58,32,1] | [3,3,32,64] | 136.13 | 115.40 MF | 847.78 GFLOPS |
| [58,58,32,8] | [3,3,32,64] | 999.01 | 923.24 MF | 924.15 GFLOPS |
| [16,16,128,8] | [3,3,128,512] | 1875.82 | 1.85 GF | 985.64 GFLOPS |

FP32 (Old vs New)

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 372.55 GFLOPS | 1.17 TFLOPS | 3.1× |
| [19,19,8,16] | [4,4,8,128] | 334.32 GFLOPS | 962.63 GFLOPS | 2.9× |
| [19,19,8,16] | [4,4,8,130] | 333.27 GFLOPS | 742.55 GFLOPS | 2.2× |
| [19,19,4,16] | [2,2,4,4] | 79.89 GFLOPS | 42.27 GFLOPS | 0.5× (slower) |
| [224,224,3,1] | [3,3,3,8] | 201.38 GFLOPS | 132.35 GFLOPS | 0.7× (slower) |
| [224,224,1,1] | [2,2,1,8] | 50.39 GFLOPS | 34.79 GFLOPS | 0.7× (slower) |
| [224,224,1,8] | [2,2,1,8] | 50.98 GFLOPS | 35.48 GFLOPS | 0.7× (slower) |
| [58,58,32,1] | [3,3,32,64] | 313.23 GFLOPS | 840.54 GFLOPS | 2.7× |
| [58,58,32,8] | [3,3,32,64] | 319.66 GFLOPS | 1.05 TFLOPS | 3.3× |
| [16,16,128,8] | [3,3,128,512] | 336.83 GFLOPS | 1.05 TFLOPS | 3.1× |

FP16 (Old vs New)

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 GFLOPS | 972.14 GFLOPS | 2.9× |
| [19,19,8,16] | [4,4,8,128] | 298.41 GFLOPS | 702.63 GFLOPS | 2.4× |
| [19,19,8,16] | [4,4,8,130] | 297.94 GFLOPS | 635.63 GFLOPS | 2.1× |
| [19,19,4,16] | [2,2,4,4] | 74.47 GFLOPS | 40.55 GFLOPS | 0.5× (slower) |
| [224,224,3,1] | [3,3,3,8] | 178.78 GFLOPS | 76.70 GFLOPS | 0.4× (slower) |
| [224,224,1,1] | [2,2,1,8] | 45.97 GFLOPS | 15.98 GFLOPS | 0.3× (slower) |
| [224,224,1,8] | [2,2,1,8] | 46.91 GFLOPS | 16.06 GFLOPS | 0.3× (slower) |
| [58,58,32,1] | [3,3,32,64] | 280.53 GFLOPS | 847.78 GFLOPS | 3.0× |
| [58,58,32,8] | [3,3,32,64] | 279.52 GFLOPS | 924.15 GFLOPS | 3.3× |
| [16,16,128,8] | [3,3,128,512] | 299.09 GFLOPS | 985.64 GFLOPS | 3.3× |

@mnehete32 (Contributor, Author)

In my benchmarks FP16 is actually slower than FP32. This is surprising, since FP16 should normally be faster on Tensor Cores. What might I be missing in my FP16 implementation? Any hints would help a lot.
@JohannesGaessler

@mnehete32 (Contributor, Author)

Note: not committed.

Also with:

#define BS_OC     16
#define BS_ICKHKW 16
#define BS_NOHOW  128

FP32

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup (vs Old) |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 372.55 GFLOPS | 822.58 GFLOPS | 2.2× |
| [19,19,8,16] | [4,4,8,128] | 334.32 GFLOPS | 647.14 GFLOPS | 1.9× |
| [19,19,8,16] | [4,4,8,130] | 333.27 GFLOPS | 610.44 GFLOPS | 1.8× |
| [19,19,4,16] | [2,2,4,4] | 79.89 GFLOPS | 67.26 GFLOPS | 0.8× (slower) |
| [224,224,3,1] | [3,3,3,8] | 201.38 GFLOPS | 275.45 GFLOPS | 1.4× |
| [224,224,1,1] | [2,2,1,8] | 50.39 GFLOPS | 102.93 GFLOPS | 2.0× |
| [224,224,1,8] | [2,2,1,8] | 50.98 GFLOPS | 111.18 GFLOPS | 2.2× |
| [58,58,32,1] | [3,3,32,64] | 313.23 GFLOPS | 554.56 GFLOPS | 1.8× |
| [58,58,32,8] | [3,3,32,64] | 319.66 GFLOPS | 643.95 GFLOPS | 2.0× |
| [16,16,128,8] | [3,3,128,512] | 336.83 GFLOPS | 601.00 GFLOPS | 1.8× |

FP16

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup (vs Old) |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 GFLOPS | 670.02 GFLOPS | 2.0× |
| [19,19,8,16] | [4,4,8,128] | 298.41 GFLOPS | 554.50 GFLOPS | 1.9× |
| [19,19,8,16] | [4,4,8,130] | 297.94 GFLOPS | 546.65 GFLOPS | 1.8× |
| [19,19,4,16] | [2,2,4,4] | 74.47 GFLOPS | 47.86 GFLOPS | 0.6× (slower) |
| [224,224,3,1] | [3,3,3,8] | 178.78 GFLOPS | 243.70 GFLOPS | 1.4× |
| [224,224,1,1] | [2,2,1,8] | 45.97 GFLOPS | 66.82 GFLOPS | 1.5× |
| [224,224,1,8] | [2,2,1,8] | 46.91 GFLOPS | 74.33 GFLOPS | 1.6× |
| [58,58,32,1] | [3,3,32,64] | 280.53 GFLOPS | 565.64 GFLOPS | 2.0× |
| [58,58,32,8] | [3,3,32,64] | 279.52 GFLOPS | 609.17 GFLOPS | 2.2× |
| [16,16,128,8] | [3,3,128,512] | 299.09 GFLOPS | 557.83 GFLOPS | 1.9× |

@Green-Sky (Collaborator)

768x1024 sd1 fp16 vae:

| method | time | memory |
|---|---|---|
| CUDA imcol+mul | ~1.68s | 4992.19 MB |
| CUDA direct (master) | ~35.35s | 1920.19 MB |
| CUDA direct (ac5e0c0) (this PR) | ~9.00s | 1920.19 MB |
| CUDA implicitgemm (2ec76aa) | ~2.20s | 1920.19 MB |
| VULKAN imcol+mul | OOM | ~4992 MB |
| VULKAN direct | ~1.17s | 1920.19 MB |

Here is the table from #15805, extended with numbers from this PR.

I used the currently pushed changes; you teased new numbers with better performance, but you still have not pushed those :)

@mnehete32 (Contributor, Author) commented Sep 13, 2025

@Green-Sky I was waiting to see which configuration we will go with. If the new configuration from this #15813 (comment) is approved, I will update the PR.

Edit: The committed configuration gives better performance when OC is large; with the new configuration, performance for large OC is lower than with the last configuration, but overall performance is better with the new configuration.

So I was waiting for approval.

@mnehete32 (Contributor, Author)

Updated PR with new configuration.

@Green-Sky (Collaborator)

> @Green-Sky I was waiting to see which configuration we will go with. If the new configuration from this #15813 (comment) is approved, I will update the PR.
>
> Edit: The committed configuration gives better performance when OC is large; with the new configuration, performance for large OC is lower than with the last configuration, but overall performance is better with the new configuration.
>
> So I was waiting for approval.

I see now. Yes, it is slightly slower in practice with sd.cpp (10.77s, same test).

@mnehete32 (Contributor, Author)

Should I revert the previous commit?

@JohannesGaessler (Collaborator)

> In my benchmarks FP16 is actually slower than FP32. This is surprising, since FP16 should normally be faster on Tensor Cores. What might I be missing in my FP16 implementation?

The problem most likely has to do with either shared memory bank conflicts (you need a stride of 16 bytes between rows/columns) or excessive type conversions.
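A small sketch of the padding idea mentioned above, with assumed tile sizes (not taken from this PR):

```cuda
// Padding each shared-memory row by 16 bytes (8 halves) shifts successive rows
// relative to the 32 banks, so strided (column-wise) reads no longer all land
// in the same bank.
#include <cuda_fp16.h>

constexpr int TILE_ROWS = 16;   // assumption
constexpr int TILE_COLS = 64;   // halves per row (assumption)
constexpr int ROW_PAD   = 8;    // 8 halves = 16 bytes of padding per row

__global__ void padded_tile_transpose(const half * src, half * dst) {
    __shared__ half tile[TILE_ROWS][TILE_COLS + ROW_PAD];

    // coalesced, row-major load from global memory
    for (int i = threadIdx.x; i < TILE_ROWS*TILE_COLS; i += blockDim.x) {
        tile[i / TILE_COLS][i % TILE_COLS] = src[i];
    }
    __syncthreads();

    // column-wise read (as a fragment/transpose load would do): without ROW_PAD
    // the row stride is 128 bytes, so all lanes of a warp hit the same bank;
    // the padding staggers the rows across banks.
    for (int i = threadIdx.x; i < TILE_ROWS*TILE_COLS; i += blockDim.x) {
        const int r = i % TILE_ROWS;
        const int c = i / TILE_ROWS;
        dst[i] = tile[r][c];
    }
}
```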

@JohannesGaessler (Collaborator)

Regarding which kernel we move forward with first: I'm fine with either; we probably need both to optimally cover all hardware. In my opinion, the ideal solution would be to first move the decision between the conv2d and im2col+gemm paths into the backends, and to then write CUDA-specific logic for choosing between convolution with tensor cores, convolution without tensor cores, and im2col+gemm.
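One possible shape for that CUDA-side selection logic, sketched with made-up names and thresholds (the enum, struct, and cutoffs below are illustrative assumptions, not ggml's actual API):

```cuda
// Hypothetical dispatch: pick a conv2d path from the problem size, the device's
// compute capability, and how much memory the im2col buffer would need.
#include <cstdint>
#include <cstddef>

enum conv2d_path {
    CONV2D_DIRECT_SCALAR,       // no tensor cores, or problem too small
    CONV2D_DIRECT_TENSOR_CORE,  // tensor-core direct convolution
    CONV2D_IM2COL_GEMM,         // existing im2col + matrix multiply path
};

struct conv2d_shape {
    int64_t N, IC, OC, OH, OW, KH, KW;
};

static conv2d_path choose_conv2d_path(const conv2d_shape & s, int cc, size_t free_bytes) {
    const bool    has_tensor_cores = cc >= 700;  // Volta and newer
    const int64_t flops  = 2*s.N*s.OC*s.OH*s.OW*s.IC*s.KH*s.KW;
    const size_t  im2col = sizeof(float)*(size_t)(s.N*s.OH*s.OW*s.IC*s.KH*s.KW);

    if (!has_tensor_cores || flops < (int64_t)(16 << 20)) {
        return CONV2D_DIRECT_SCALAR;       // small problems stay memory-bound anyway
    }
    if (im2col <= free_bytes/2) {
        return CONV2D_IM2COL_GEMM;         // fastest today when the buffer fits
    }
    return CONV2D_DIRECT_TENSOR_CORE;      // large problem, not enough memory for im2col
}
```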

@JohannesGaessler (Collaborator)

Anyways, my criterion for approving and merging either one of the PRs as-is is for them to preserve existing functionality (as the tensor core PR would currently render some GPUs unusable) and to either be consistently faster than master or to add logic for selecting the optimal code paths depending on hardware and tensors.

@mnehete32 (Contributor, Author)

> In my benchmarks FP16 is actually slower than FP32. This is surprising, since FP16 should normally be faster on Tensor Cores. What might I be missing in my FP16 implementation?
>
> The problem most likely has to do with either shared memory bank conflicts (you need a stride of 16 bytes between rows/columns) or excessive type conversions.

Nsight Compute shows uncoalesced shared memory access, which I think comes from the tensor core loads. The stall wait is about 3.4 cycles on average, compared to 3.2 for FP32, so the extra delay is mostly due to type conversion.

@mnehete32 (Contributor, Author)

> Anyways, my criterion for approving and merging either one of the PRs as-is is for them to preserve existing functionality (as the tensor core PR would currently render some GPUs unusable) and to either be consistently faster than master or to add logic for selecting the optimal code paths depending on hardware and tensors.

This PR also includes handling for GPUs without tensor cores, so they won’t be affected.

@mnehete32 (Contributor, Author) commented Sep 15, 2025

I wanted to check — should I make further changes to this PR, or would it be better to close it since it hasn’t been accepted due to performance concerns?
@JohannesGaessler

@JohannesGaessler (Collaborator)

I'm fine with either.

@mnehete32 (Contributor, Author)

New results

Optimization highlights:

  1. Shrunk shared memory usage to get better occupancy.
  2. Added extra warps; not all of them do compute, some only help with loading data into shared memory (see the sketch below).
  3. Made sure the extra warps don't do any unnecessary computation.
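A compilable toy sketch of point 2, with the roles split the same way but applied to a trivial reduction instead of the convolution; warp counts and tile size are assumptions, not the PR's kernel:

```cuda
// Loader warps only stage tiles into shared memory; compute warps only consume
// them. Illustrative only.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int WARP_SIZE     = 32;
constexpr int LOADER_WARPS  = 2;   // stage data into shared memory
constexpr int COMPUTE_WARPS = 2;   // do the math
constexpr int TILE          = 256; // elements staged per iteration

__global__ void split_warp_sum(const float * x, float * out, int n) {
    __shared__ float tile[TILE];
    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;
    float acc = 0.0f;

    for (int base = 0; base < n; base += TILE) {
        if (warp < LOADER_WARPS) {
            // loader warps: coalesced copy of one tile into shared memory
            for (int i = warp*WARP_SIZE + lane; i < TILE; i += LOADER_WARPS*WARP_SIZE) {
                tile[i] = (base + i < n) ? x[base + i] : 0.0f;
            }
        }
        __syncthreads();
        if (warp >= LOADER_WARPS) {
            // compute warps: consume the staged tile
            const int cw = warp - LOADER_WARPS;
            for (int i = cw*WARP_SIZE + lane; i < TILE; i += COMPUTE_WARPS*WARP_SIZE) {
                acc += tile[i];
            }
        }
        __syncthreads();
    }
    if (warp >= LOADER_WARPS) {
        atomicAdd(out, acc);
    }
}

int main() {
    const int n = 1 << 20;
    float *x, *out;
    cudaMallocManaged(&x, n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 1.0f;
    *out = 0.0f;
    split_warp_sum<<<1, (LOADER_WARPS + COMPUTE_WARPS)*WARP_SIZE>>>(x, out, n);
    cudaDeviceSynchronize();
    printf("sum = %.0f (expected %d)\n", *out, n);
    cudaFree(x); cudaFree(out);
    return 0;
}
```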

FP32

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup (vs Old) |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 372.55 GFLOPS | 1.52 TFLOPS | 4.1× |
| [19,19,8,16] | [4,4,8,128] | 334.32 GFLOPS | 1.33 TFLOPS | 4.0× |
| [19,19,8,16] | [4,4,8,130] | 333.27 GFLOPS | 1.13 TFLOPS | 3.4× |
| [19,19,4,16] | [2,2,4,4] | 79.89 GFLOPS | 102.07 GFLOPS | 1.3× |
| [224,224,3,1] | [3,3,3,8] | 201.38 GFLOPS | 296.39 GFLOPS | 1.5× |
| [224,224,1,1] | [2,2,1,8] | 50.39 GFLOPS | 93.37 GFLOPS | 1.9× |
| [224,224,1,8] | [2,2,1,8] | 50.98 GFLOPS | 97.86 GFLOPS | 1.9× |
| [58,58,32,1] | [3,3,32,64] | 313.23 GFLOPS | 1.18 TFLOPS | 3.8× |
| [58,58,32,8] | [3,3,32,64] | 319.66 GFLOPS | 1.32 TFLOPS | 4.1× |
| [16,16,128,8] | [3,3,128,512] | 336.83 GFLOPS | 1.33 TFLOPS | 3.9× |

FP16 (Tensor Core)

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup (vs Old) |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 GFLOPS | 1.97 TFLOPS | 5.8× |
| [19,19,8,16] | [4,4,8,128] | 298.41 GFLOPS | 1.64 TFLOPS | 5.5× |
| [19,19,8,16] | [4,4,8,130] | 297.94 GFLOPS | 1.37 TFLOPS | 4.6× |
| [19,19,4,16] | [2,2,4,4] | 74.47 GFLOPS | 109.93 GFLOPS | 1.5× |
| [224,224,3,1] | [3,3,3,8] | 178.78 GFLOPS | 367.34 GFLOPS | 2.1× |
| [224,224,1,1] | [2,2,1,8] | 45.97 GFLOPS | 114.22 GFLOPS | 2.5× |
| [224,224,1,8] | [2,2,1,8] | 46.91 GFLOPS | 121.36 GFLOPS | 2.6× |
| [58,58,32,1] | [3,3,32,64] | 280.53 GFLOPS | 1.40 TFLOPS | 5.0× |
| [58,58,32,8] | [3,3,32,64] | 279.52 GFLOPS | 1.75 TFLOPS | 6.3× |
| [16,16,128,8] | [3,3,128,512] | 299.09 GFLOPS | 1.75 TFLOPS | 5.9× |

FP16 (without Tensor Core)

| Input Shape | Kernel Shape | Before Perf | After Perf | Speedup (vs Old) |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 GFLOPS | 1.43 TFLOPS | 4.2× |
| [19,19,8,16] | [4,4,8,128] | 298.41 GFLOPS | 1.26 TFLOPS | 4.2× |
| [19,19,8,16] | [4,4,8,130] | 297.94 GFLOPS | 1.07 TFLOPS | 3.6× |
| [19,19,4,16] | [2,2,4,4] | 74.47 GFLOPS | 101.22 GFLOPS | 1.4× |
| [224,224,3,1] | [3,3,3,8] | 178.78 GFLOPS | 289.25 GFLOPS | 1.6× |
| [224,224,1,1] | [2,2,1,8] | 45.97 GFLOPS | 90.87 GFLOPS | 2.0× |
| [224,224,1,8] | [2,2,1,8] | 46.91 GFLOPS | 95.02 GFLOPS | 2.0× |
| [58,58,32,1] | [3,3,32,64] | 280.53 GFLOPS | 1.12 TFLOPS | 4.0× |
| [58,58,32,8] | [3,3,32,64] | 279.52 GFLOPS | 1.32 TFLOPS | 4.7× |
| [16,16,128,8] | [3,3,128,512] | 299.09 GFLOPS | 1.28 TFLOPS | 4.3× |

@Green-Sky @JohannesGaessler

CUDA: uint to int and added assertion
@Green-Sky (Collaborator) commented Sep 16, 2025

Updated table:

768x1024 sd1 fp16 vae:

| method | time | memory |
|---|---|---|
| CUDA imcol+mul | ~1.68s | 4992.19 MB |
| CUDA direct (master) | ~35.35s | 1920.19 MB |
| CUDA direct (ac5e0c0) (this PR, old) | ~9.00s | 1920.19 MB |
| CUDA direct (6049576) (this PR) | ~5.05s | 1920.19 MB |
| CUDA implicitgemm (2ec76aa) | ~2.20s | 1920.19 MB |
| VULKAN imcol+mul | OOM | ~4992 MB |
| VULKAN direct | ~1.17s | 1920.19 MB |

@etasnadi (Contributor) commented Sep 17, 2025

Hi, I just noticed this commit.

Has anyone tested how the performance compares to the Vulkan implementation? There is also a coopmat2 implementation that uses matrix cores.

Edit: I see that there is a 2-5x improvement compared to the scalar kernel, but the Vulkan scalar kernel is 8-10x faster than CUDA's, so it would be good to test whether it at least outperforms the Vulkan scalar kernel; to be fair, though, it should be compared to the coopmat2 version.

@Green-Sky (Collaborator)

> Hi, I just noticed this commit.
>
> Has anyone tested how the performance compares to the Vulkan implementation? There is also a coopmat2 implementation that uses matrix cores.
>
> Edit: I see that there is a 2-5x improvement compared to the scalar kernel, but the Vulkan scalar kernel is 8-10x faster than CUDA's, so it would be good to test whether it at least outperforms the Vulkan scalar kernel; to be fair, though, it should be compared to the coopmat2 version.

At least in the context of stable-diffusion.cpp VAE decode, my table provides coopmat2 numbers. I know it's not very exhaustive, but still.

@etasnadi (Contributor) commented Sep 17, 2025

Indeed, this kernel is considerably slower (0.3-0.5x) than the basic scalar Vulkan shader on my device (2060).

Then where is the catch? Maybe the perf test cases are not representative of sd.cpp?

Vulkan:

  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     41 runs - 24893.68 us/run - 137.42 GFLOP/run -   5.52 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20196 runs -    51.37 us/run - 133.69 MFLOP/run -   2.60 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14003 runs -    72.59 us/run - 135.78 MFLOP/run -   1.87 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                196608 runs -     5.18 us/run - 642.82 kFLOP/run - 123.98 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                47860 runs -    22.75 us/run -  20.90 MFLOP/run - 918.35 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                73728 runs -    13.97 us/run -   2.78 MFLOP/run - 199.36 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    92.25 us/run -  22.28 MFLOP/run - 241.49 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               22542 runs -    45.59 us/run - 115.40 MFLOP/run -   2.53 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   320.96 us/run - 923.24 MFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1595 runs -   642.80 us/run -   1.85 GFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     42 runs - 23886.05 us/run - 137.42 GFLOP/run -   5.75 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20196 runs -    51.05 us/run - 133.69 MFLOP/run -   2.62 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14003 runs -    72.37 us/run - 135.78 MFLOP/run -   1.88 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                196608 runs -     5.15 us/run - 642.82 kFLOP/run - 124.77 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                47860 runs -    22.66 us/run -  20.90 MFLOP/run - 922.24 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                73728 runs -    13.95 us/run -   2.78 MFLOP/run - 199.68 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    92.79 us/run -  22.28 MFLOP/run - 240.10 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               22542 runs -    45.19 us/run - 115.40 MFLOP/run -   2.55 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   320.30 us/run - 923.24 MFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1595 runs -   640.66 us/run -   1.85 GFLOP/run -   2.89 TFLOPS

This (6049576):

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
Testing 3 devices

Backend 1/3: CUDA0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 7787 MB (7693 MB free)
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     12 runs - 86240.25 us/run - 137.42 GFLOP/run -   1.59 TFLOPS

  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               11220 runs -    94.04 us/run - 133.69 MFLOP/run -   1.42 TFLOPS

  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                9581 runs -   111.03 us/run - 135.78 MFLOP/run -   1.22 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                163840 runs -     6.13 us/run - 642.82 kFLOP/run - 104.89 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                19144 runs -    60.96 us/run -  20.90 MFLOP/run - 342.80 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                40960 runs -    25.34 us/run -   2.78 MFLOP/run - 109.92 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 8978 runs -   187.20 us/run -  22.28 MFLOP/run - 119.01 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               12138 runs -    83.67 us/run - 115.40 MFLOP/run -   1.38 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                1744 runs -   590.27 us/run - 923.24 MFLOP/run -   1.56 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      880 runs -  1203.80 us/run -   1.85 GFLOP/run -   1.54 TFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     16 runs - 63580.88 us/run - 137.42 GFLOP/run -   2.16 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14212 runs -    73.23 us/run - 133.69 MFLOP/run -   1.83 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               11792 runs -    86.66 us/run - 135.78 MFLOP/run -   1.57 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                188416 runs -     5.45 us/run - 642.82 kFLOP/run - 117.93 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                23930 runs -    48.73 us/run -  20.90 MFLOP/run - 428.86 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                57344 runs -    20.10 us/run -   2.78 MFLOP/run - 138.53 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 8978 runs -   151.65 us/run -  22.28 MFLOP/run - 146.91 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14739 runs -    68.26 us/run - 115.40 MFLOP/run -   1.69 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2289 runs -   451.17 us/run - 923.24 MFLOP/run -   2.05 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1100 runs -   910.24 us/run -   1.85 GFLOP/run -   2.03 TFLOPS
  Backend CUDA0: OK
Backend 2/3: Vulkan0
  Skipping
Backend 3/3: CPU
  Skipping
3/3 backends passed
OK

@etasnadi (Contributor)

> Updated table:
>
> 768x1024 sd1 fp16 vae:
>
> | method | time | memory |
> |---|---|---|
> | CUDA imcol+mul | ~1.68s | 4992.19 MB |
> | CUDA direct (master) | ~35.35s | 1920.19 MB |
> | CUDA direct (ac5e0c0) (this PR, old) | ~9.00s | 1920.19 MB |
> | CUDA direct (6049576) (this PR) | ~5.05s | 1920.19 MB |
> | CUDA implicitgemm (2ec76aa) | ~2.20s | 1920.19 MB |
> | VULKAN imcol+mul | OOM | ~4992 MB |
> | VULKAN direct | ~1.17s | 1920.19 MB |

So the Vulkan direct conv 2d version is 4.5x faster?

@mnehete32 (Contributor, Author)

#16088 is a better implementation.

@mnehete32 closed this on Sep 18, 2025
@etasnadi (Contributor) commented Sep 18, 2025 via email

@mnehete32 (Contributor, Author)

As I mentioned in my earlier PR comment, I’d be happy to take a shot at an fp16 version.

@etasnadi (Contributor)

> As I mentioned in my earlier PR comment, I'd be happy to take a shot at an fp16 version.

Ok, but I still think that instead of closing this it would be better to rebase onto the branch in #16088 and then replace the dot product loop with your ggml_cuda_mma implementation, if you are still motivated. I think the code will be way faster with mma on large problems like the 4096-by-4096 one.

@mnehete32 (Contributor, Author)

I’m still up for it, and I agree MMA should give a nice boost on big cases like 4096×4096. I’ll give it a shot.
