CUDA: Conv2d Tensor Core #15813
Conversation
* removed flash-attention definition (force-pushed from 57aa09e to 2cd9fb0)
If you're going to make your own primitives anyway, take a look at mma.cuh. The WMMA interface NVIDIA provides for "high-level" CUDA code is quite frankly terrible, so I exposed the tensor core PTX instructions (the assembly equivalent). The practical upside is that you get a defined memory layout (important for mul_mat_q and FlashAttention, but I think not here) and that you can use smaller matrix tiles (the minimum is 16x8). The downside is that Volta and AMD are still lacking an implementation.
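For context, a minimal sketch of the high-level WMMA interface being contrasted with mma.cuh (assuming a Volta-or-newer compile target and the fixed 16x16x16 FP16 tile shape); this is illustrative only and not code from this PR. Note that the element-to-thread mapping inside the fragments is opaque, which is exactly the layout limitation mentioned above.

```cuda
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// One warp (launch with 32 threads) computes a single 16x16 output tile:
// C = A * B, with A/B in FP16 and the accumulator in FP32.
// lda/ldb/ldc are leading dimensions in elements.
__global__ void wmma_tile_16x16x16(const half * A, const half * B, float * C,
                                   int lda, int ldb, int ldc) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);        // zero the accumulator fragment
    wmma::load_matrix_sync(a_frag, A, lda);   // warp-cooperative loads
    wmma::load_matrix_sync(b_frag, B, ldb);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, ldc, wmma::mem_row_major);
}
```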
It is becoming increasingly hard to test these kinds of changes with
I only monitor the llama.cpp and ggml repositories, but when it comes to convolution kernels such as this one it would also be fine for me if you open PRs in sd.cpp and tag me.
@mnehete32 please run tests, the output seems to be broken.
@Green-Sky Checking
I think the kernel was not able to launch all of its threads, since the launcher launches one warp per WMMA_M, WMMA_N tile. I will work on launching fewer threads per block, also with the
FYI, for tensor cores you theoretically don't need shared memory at all. Each thread in a warp holds fractions of the input and output tiles in its registers. You only need shared memory to organize the data in such a way that the global memory accesses are coalesced (see
I thought that because the output-to-thread mapping is unknown and changes based on architecture, I first need to load the output into shared memory before storing it.
If you read the PTX documentation you'll find that all tensor core instructions have a well-defined memory layout. It's only when you try to cover all tensor core instructions with a simple interface that you run into problems. Volta has 8x8 tensor cores. Turing, Ampere, and Ada Lovelace have 16x8 tensor cores (used by
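To make the "well-defined memory layout" point concrete, here is a hedged sketch of the raw PTX tensor core instruction for a single 16x8 tile on Ampere or newer (sm_80+); it is illustrative and not the actual mma.cuh code. The per-thread distribution of the a/b/d fragments is fixed by the PTX ISA, which is what makes the layout well defined.

```cuda
// D += A*B for one 16x8 FP32 output tile with K = 16. The FP16 A and B
// fragments are packed two halves per 32-bit register and distributed across
// the 32 lanes of the warp in the layout specified by the PTX ISA.
__device__ __forceinline__ void mma_m16n8k16(float d[4], const unsigned int a[4],
                                             const unsigned int b[2]) {
#if __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
        : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]));
#endif
}
```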
Warps were not covering the whole block; I have fixed the issue. I have tested in sd.cpp.
Convolution performance: FP32 (float32), FP16 (float16), FP32 (old vs new), FP16 (old vs new)
In my benchmarks FP16 is actually slower than FP32. This is surprising, since FP16 should normally be faster on Tensor Cores.
Note: not committed. Also with: FP32, FP16
768x1024 sd1 fp16 vae:
Here is the table from #15805 extended with numbers from this PR. I used the currently pushed changes; you teased new numbers with better performance, but you still have not pushed those :)
@Green-Sky I was waiting to see which configuration we will go with. If the new configuration from #15813 (comment) is approved, I will update the PR. Edit: The committed configuration gives better performance when OC is large; with the new configuration, performance drops below the last configuration for large OC, but overall performance is better with the new configuration. So I was waiting for approval.
Updated PR with new configuration.
I see now. Yes, it is slightly slower in practice with sd.cpp (10.77s, same test).
Should I revert the previous commit?
The problem most likely has to do with either shared memory bank conflicts (you need a stride of 16 bytes between rows/columns) or excessive type conversions.
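A minimal sketch of the bank-conflict suggestion (tile sizes and names are illustrative, not taken from this PR): skew each shared-memory row by 16 bytes, i.e. 8 half elements, so that the 16-byte-wide tensor core loads from successive rows land on different banks.

```cuda
#include <cuda_fp16.h>

constexpr int TILE_ROWS = 64;
constexpr int TILE_COLS = 64;  // in half elements
constexpr int SKEW      = 8;   // 8 halves = 16 extra bytes of stride per row

__global__ void conv2d_tile_sketch() {
    // Row stride is TILE_COLS + SKEW halves instead of TILE_COLS, which shifts
    // the bank pattern of each row relative to the previous one.
    __shared__ half tile[TILE_ROWS][TILE_COLS + SKEW];
    (void) tile;  // the real kernel's loads/stores would go here
}
```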
Regarding which kernel we move forward with first: I'm fine with either; we probably need both to optimally cover all hardware. The ideal solution in my opinion would be to first move the decision between the conv2d and im2col+gemm paths into the backends, and then to write CUDA-specific logic for choosing between convolution with tensor cores, convolution without tensor cores, and im2col+gemm.
Anyway, my criterion for approving and merging either one of the PRs as-is is that they preserve existing functionality (the tensor core PR would currently render some GPUs unusable) and are either consistently faster than master or add logic for selecting the optimal code path depending on hardware and tensor shapes.
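A hypothetical sketch of what such backend-side path selection could look like; the enum, function, and thresholds below are invented for illustration and are not existing ggml/llama.cpp code.

```cuda
#include <cstdint>

// Pick a conv2d strategy per device and tensor shape (illustrative heuristic only).
enum class conv2d_path { TENSOR_CORE, DIRECT, IM2COL_GEMM };

static conv2d_path choose_conv2d_path(bool has_tensor_cores, bool fp16_ok,
                                      int64_t OC, int64_t IC, int64_t KW, int64_t KH) {
    const int64_t k = IC*KW*KH;           // reduction length of the implicit GEMM
    if (has_tensor_cores && fp16_ok && OC % 8 == 0 && k % 16 == 0) {
        return conv2d_path::TENSOR_CORE;  // shapes line up with 16x8 MMA tiles
    }
    if (k >= 4096) {
        return conv2d_path::IM2COL_GEMM;  // large reductions: a cuBLAS GEMM tends to win
    }
    return conv2d_path::DIRECT;           // otherwise: direct conv, no im2col overhead
}
```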
Nsight Compute shows uncoalesced shared memory access, which I think comes from the tensor core loads. The stall wait is about 3.4 cycles on average, compared to 3.2 for FP32, so the extra delay is mostly due to type conversion.
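If conversions really are the bottleneck, one possible mitigation (a sketch, not code from this PR) is to convert the FP32 accumulators to FP16 two at a time, which halves the number of conversion instructions and enables 4-byte stores; dst and acc are illustrative names.

```cuda
#include <cuda_fp16.h>

// dst must be 4-byte aligned and n even; acc holds the per-thread FP32 results.
__device__ void store_accumulators_as_half(half * dst, const float * acc, int n) {
    for (int i = 0; i < n; i += 2) {
        *reinterpret_cast<half2 *>(dst + i) = __floats2half2_rn(acc[i], acc[i + 1]);
    }
}
```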
This PR also includes handling for GPUs without tensor cores, so they won't be affected.
I wanted to check: should I make further changes to this PR, or would it be better to close it since it hasn't been accepted due to performance concerns?
I'm fine with either.
New results. Optimization highlights: FP32, FP16 (Tensor Core), FP16 (without Tensor Core)
Force-pushed from 74f4907 to 6e4a26a
CUDA: uint to int and added assertion (force-pushed from 6e4a26a to 6049576)
Updated table: 768x1024 sd1 fp16 vae:
Hi, I just noticed this commit. Has anyone tested how the performance compares to the Vulkan implementation? There is also a coopmat2 implementation that uses matrix cores. Edit: I see that there is a 2-5x improvement compared to the scalar kernel, but the Vulkan scalar kernel is 8-10x faster than CUDA's, so it would be good to test whether this at least outperforms the Vulkan scalar kernel; to be fair, it should be compared to the coopmat2 version.
At least in the context of stable-diffusion.cpp VAE decode, my table provides coopmat2 numbers.
Indeed, this kernel is considerably slower (0.3-0.5x) than the basic scalar Vulkan shader on my device (2060). Where is the catch, then? Maybe the perf test cases are not representative of sd.cpp? Vulkan:
This PR (6049576):
So the Vulkan direct conv2d version is 4.5x faster?
#16088 is a better implementation.
The Tensor Cores implementation is still missing and the Vulkan coopmat2 impl cannot be directly ported, so if you could vectorize the inner loop with wmma/mma that would be a big deal, if you are still motivated to work on this.
As I mentioned in my earlier PR comment, I'd be happy to take a shot at an fp16 version.
Ok, but I still think that instead of closing this it would be better to rebase onto the branch in #16088 and then replace the dot product loop with your
I'm still up for it, and I agree MMA should give a nice boost on big cases like 4096×4096. I'll give it a shot.
Follow-up of PR #15635
Convolution Performance Results (Old)
FP32 (float32) Performance
FP16 (float16) Performance
Convolution Performance Results (New)
FP32 (float32) Performance
FP16 (float16) Performance
Convolution Performance Comparison (Old vs New)
FP32 (float32)
FP16 (float16)
Summary:
* ggml_cuda_cast<T> cast, to make sure it doesn't break the build.
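For illustration, a minimal sketch of what a templated device-side cast helper in the spirit of ggml_cuda_cast<T> typically looks like; this is an assumption about its shape, not the actual ggml implementation.

```cuda
#include <cuda_fp16.h>
#include <type_traits>

// A single cast helper keeps both the FP16 and FP32 kernel instantiations building.
template <typename dst_t, typename src_t>
__device__ __forceinline__ dst_t cast_value(src_t x) {
    if constexpr (std::is_same_v<dst_t, half> && std::is_same_v<src_t, float>) {
        return __float2half(x);
    } else if constexpr (std::is_same_v<dst_t, float> && std::is_same_v<src_t, half>) {
        return __half2float(x);
    } else {
        return static_cast<dst_t>(x);  // trivial cases, e.g. same type
    }
}
```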