HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 #14624
Conversation
I would be happy to get on a call with you to discuss AMD hardware support; my email address can be found on my GitHub page.
@deepsek Thanks for the contribution and for reaching out. On topics related to the CUDA backend, @JohannesGaessler is the best person to consult with. For additional backends, @slaren can provide guidelines and advice. I'll be happy to provide input on any matters as well. I am also available for a call - feel free to contact me.
Very nice to see the initiative. I assume improvements made for CDNA will also carry over to the consumer side next year when UDNA releases. So this is exciting news for the future of AMD products!
This certainly is good news.
Sorry, I wanted to ask: @IMbackK, since you've been working on AMD support, are you interested in joining the discussion?
Yes, certainly. It would help to avoid duplication of effort. I can be reached via email at uvos.xyz, user carl.
Hi @JohannesGaessler, is there any blocker for merging this PR into the main branch?
@deepsek There are a few small things as discussed. One is better naming for this MFMA path, so that an RDNA WMMA solution can be added later without the naming being strange. Another is the use of two V_MFMA_I32_16X16X16I8 instructions on gfx908 and gfx90a, even if this path is not chosen for those, to ease maintainability. I would also like to try this myself on gfx94x somehow, and I am not sure what the state is with regard to access to AMD's cloud for maintenance of a gfx94x-specific code path; maybe @ggerganov can also comment on that. A problem here is that after CDNA2/gfx90a/MI210 AMD has not made any further CDNA devices in a PCIe add-in board form factor, so outside of acquiring an entire MI300 OAM machine no one can simply add a CDNA3/gfx94x/MI3xx compatible card to their system.
@deepsek Oops, sorry, I accidentally edited your post instead of quoting it, please repost. From my side there is nothing further missing; I'd just like to give it another spin to test for regressions, and I will approve after.
No worries haha. I was just saying. Based on all the comments so far, looks like
Is there actual guidance as to when performance is preferred over memory usage? There seem to be conflicting viewpoints. It would be great to have this documented, so that when we contribute and add other architectures there is a common design principle. P.S.,
```cpp
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
} else {
    int64_t * xi = (int64_t *) t.x;
    const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
    xi[0] = xs[0];
}
```
```diff
-if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
-#pragma unroll
-    for (int l = 0; l < t.ne; ++l) {
-        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
-    }
-} else {
-    int64_t * xi = (int64_t *) t.x;
-    const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-    xi[0] = xs[0];
-}
+if constexpr (I != 64 || J != 2) {
+    int64_t * xi = (int64_t *) t.x;
+    const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+    xi[0] = xs[0];
+    return;
+}
```
I think this would be simpler.
Do you mean without the preprocessor directives?
This would affect the NV code path when we call load_generic though? I see some instances where load_generic is called
```cpp
static __device__ __forceinline__ void load_generic(...) {
    if constexpr (I != 64 || J != 2) {
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
        return;
    }

#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
}
```
I basically meant to have the instructions for loading data as 64 bit encapsulated in an `#ifdef AMD_MFMA_AVAILABLE ... #endif` block,
and to use the generic implementation if the preconditions aren't met. But if this is going to be refactored anyway it doesn't matter.
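For illustration, a minimal sketch of that structure, reusing the (elided) signature and tile members from the snippet above; this is not the actual patch, just one way the `#ifdef` could be arranged:

```cpp
// Illustrative sketch, not the actual patch: the 64-bit wide load only exists
// on the AMD MFMA path; every other build (including the NVIDIA path) falls
// through to the generic per-element loop.
static __device__ __forceinline__ void load_generic(...) {
#if defined(AMD_MFMA_AVAILABLE)
    if constexpr (I != 64 || J != 2) { // regular MFMA tiles: load 64 bits per thread
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
        return;
    }
#endif // AMD_MFMA_AVAILABLE

#pragma unroll
    for (int l = 0; l < t.ne; ++l) { // generic fallback: one element at a time
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
}
```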
I don't understand what you're doing with
It's very hard to have a single design principle for every situation. Historically, some models would need as much as 6 GB of VRAM for the dequantization buffer. This would cause a lot of problems for people using consumer GPUs who would not understand why they could not offload to the GPU as many layers of the model as they would expect. This was one of the reasons why MMQ was made the default, even though it was not faster than cuBLAS in every situation. Ultimately, it doesn't matter if cuBLAS is 10% faster if using it means that you need to keep a large portion of the model on a CPU that is 10 times slower. For a data center GPU where VRAM is not so limited, the calculus may be different.
I ran the following tests on my RTX 3090/4090:
With
On my RTX 4090, MMQ is faster for all quantization formats except q2_K at batch sizes <= 2048; for batch sizes <= 1024 MMQ is always faster. On my RTX 3090, MMQ is faster for all quantization formats except q2_K at batch sizes <= 512.
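Restating those numbers as a hypothetical selection helper, purely for illustration; the names, structure, and thresholds below are not the actual llama.cpp dispatch logic (master enables MMQ unconditionally on these GPUs):

```cpp
// Purely illustrative: the batch-size observations above encoded as a helper.
enum class gpu_arch { turing, ampere, ada_lovelace };

static bool prefer_mmq_over_cublas(gpu_arch arch, int batch_size, bool is_q2_K) {
    switch (arch) {
        case gpu_arch::ada_lovelace: // RTX 4090: always faster up to 1024, all but q2_K up to 2048
            return batch_size <= 1024 || (!is_q2_K && batch_size <= 2048);
        case gpu_arch::ampere:       // RTX 3090: faster for everything except q2_K up to 512
            return !is_q2_K && batch_size <= 512;
        default:                     // elsewhere the memory savings usually dominate
            return true;
    }
}
```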
The MMQ code is designed around tensor core instructions that were introduced with Ampere. Hopper can also make use of these instructions, but it has additional tensor core instructions that are, to my knowledge, only found on Hopper and no earlier or later generation. Presumably the cuBLAS code makes use of these instructions; the MMQ code definitely does not. I have never tested the code on or tuned it for Hopper.
The tensor core instructions that MMQ was written around are not available on Turing. However, there are similar tensor core instructions which work on tiles that are exactly half as large, so the same numerical result can be obtained by executing two tensor core instructions instead. I do not own any Turing hardware and have not tuned the code specifically for this architecture.
No one wrote down the exact criteria for when to use cuBLAS vs. MMQ. My general opinion is that if you use anything below q8_0 you are already trading quality for lower memory usage. The hardware that I have been focusing on is the RTX 3090/4090 because those are, in my opinion, the best "cheap" GPUs with 24 GB VRAM; on those, MMQ performs well enough that I think using it unconditionally is the correct choice. On an RTX 2080 Ti with only 11 GB VRAM it's even more important to keep memory usage low, so I decided to enable MMQ unconditionally as well, under the assumption that the tradeoffs would be similar to Ampere/Ada Lovelace. The logic on master for AMD was written by @IMbackK. I don't remember whether he posted the performance numbers upon which his decisions were based in the relevant PRs. In any case, I don't have a good overview of the AMD hardware stack and decided to go with his judgement. For CDNA3 in particular, my understanding is that all GPUs using that architecture have at least 128 GB of memory. For that particular hardware I therefore think that the correct choice is to simply maximize speed.
I have documented the design decisions that seemed unintuitive to me at the time. However, I think it is generally difficult to judge which parts of your own work need to be documented in order to make them understandable to third parties. Since you have already gone through the trouble of understanding the code as-is, I would be grateful if you could add comments in those places where they would have helped your understanding.
Sounds good. Great to hear the reasoning behind these decisions.
The purpose of this is to leverage the 16x16x32 MFMA instruction (16x8 tile) over 32x32x16 (32x4 tile). This gives some performance increase and also fixes nwarps to 8 for all quants. I use this specific 'placeholder' tile to load the same <16, 4> tile twice as a <16, 8> tile, since the current architecture's matrix cores don't support a 16x16x16 instruction. With this compute, the result is basically double the value needed, hence in these cases the scale value is halved in the code. load_tile is set to <64, 2> only because the tile type calculates the number of elements per thread and I need it to stay at 2 (== 64*2/64), since <16, 8> and <32, 4> are already taken. Hence the tag as a special tile used to achieve this.
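For illustration, the element-count arithmetic being described, as a small self-contained sketch; `tile_shape` and `warp_size` are hypothetical names, not the actual MMQ tile template, and the 64-lane wavefront size is assumed for CDNA:

```cpp
// Illustrative sketch of the per-thread element count for a tile<I, J> on a
// 64-lane wavefront: ne = I*J / 64. The <64, 2> "placeholder" keeps ne == 2,
// matching <16, 8> and <32, 4>, which are already taken for other layouts.
template <int I, int J>
struct tile_shape {
    static constexpr int warp_size = 64;                 // CDNA wavefront size (assumption)
    static constexpr int ne        = I * J / warp_size;  // elements held per thread
};

static_assert(tile_shape<16, 8>::ne == 2, "regular 16x8 tile");
static_assert(tile_shape<32, 4>::ne == 2, "regular 32x4 tile");
static_assert(tile_shape<64, 2>::ne == 2, "special placeholder tile for loading <16, 4> as <16, 8>");
```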
If I understand you correctly, you are solving the problem where some of the quantization types have one scale per four 32-bit integers, so the 16x8 tiles are too long and don't yield the correct results. Have you considered loading the data as 16x8 tiles as you do in e.g.
and a corresponding function
Yeah. I was initially going down that same road when I started removing the larger (32x4) tiles. I was going to use a bitmask to simply clear unused threads. But that would require quite a few more changes to the code and to how the loops are written right now. In the interest of saving time and preventing this PR from dangling too long, I just chose the other route to achieve the same with minimal changes. We can revisit this as part of a larger redesign at a later date if you'd like.
I don't think the current solution is good, but I'm willing to approve the PR as-is as long as you promise to refactor the code in the near future.
Needed a manual rebase, but works great, and on multiple AMD Radeon RX 7900 XT cards it resolves the "Page not present or supervisor privilege" GPU crash on K-shift.
This PR should not affect gfx11 in any way; that's also the first I've heard of this problem.
I have tested this PR one more time and am satisfied that there are no further regressions, except the small regression in the dp4a code path in Q4_0 and Q4_1 on gfx9xx, which I find acceptable given the performance benefit in other areas.
@deepsek Given the overall similarity between MFMA and WMMA int8 instructions, is there a plan to enable int8 MMQ for RDNA3/4? It could greatly improve performance, at least on RDNA4, due to its doubled int8 MMA throughput over fp16.
@hjc4869 Yes, there is currently an investigation going on to do exactly that! Assuming everything goes smoothly, you should expect something to drop from the team in the near future. @jiachengjason
To be clear, the WMMA interface as provided by e.g. NVIDIA is useless for MMQ because you don't get a defined data layout and therefore cannot apply the per-block scales correctly unless you go through shared memory (slow). That is why I went to the trouble of writing my own primitives in
WMMA in this case means the gfx11+ WMMA instructions, not rocWMMA; slightly confusing naming.
Hi, before PR merge:
after PR merge:
While bisecting, I also ran into one build where FA is either completely broken, or extremely well optimized, but for now I did not examine it further:
EDIT: Tracked down the FA regression to a86f52b (tag b5973), so that seems unrelated to this PR.
I can reproduce the regression:
I don't at all understand why this is happening, though, since the dp4a code path for a warp size of 32 should be unchanged.
The performance regression should be fixed with #15014.
Good catch! Looks like the last line of the
Added Matrix core support (MFMA instructions) for MMQ kernels.
Enabled stream-K for CDNA3 to work with MMQ kernels.
Removed usage of the hardcoded WARP_SIZE constant in MMQ kernels.
NOTE: Thoughts on removing all uses of hardcoded constants that are specific to NVIDIA (like WARP_SIZE) in order to support other GPUs?
@JohannesGaessler @ggerganov
P.S. I am part of an AMD team actively working on enabling the AMD feature set in llama.cpp. We would like to get on a call to discuss some future PR plans for additional backends, flash attention changes, etc.
EDIT:
Updated to add some performance charts for the DeepSeekV3 model.
Upstream vs ROCm Fork Development

MI300X vs H100 Throughput Test
