-
Notifications
You must be signed in to change notification settings - Fork 709
support swapAB for m_grouped_fp8_gemm_nt_masked #192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
block_ns.push_back(i); | ||
if(get_env<int>("ENABLE_SWAPAB")){ | ||
block_ms = std::vector{32}; // 32, 64 | ||
block_ns = std::vector{256}; // 64, 128, 256 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Manually set one of them. Experiments have found that in most cases, 256 performs the best
Thanks! Merging it later. |
f2e2357
to
2991c77
Compare
Hi~ Do you have a plan for code merge? We have already used this PR on our online service for H20. |
LGTM |
为什么在other case里 num_groups=1, expected_m_per_group=1024, n=4096, k=7168 这个case也能有提升?num groups=1时实际上相当于是一个1024×4096×7168的矩阵乘吧?SwapAB在这里的优势是什么 |
Sorry, we will try to merge this by the end of Oct. As swap AB will introduce non-batch-invariant and deterministic issues, we will consider it more carefully and do some refactors before merging. Also, as most the code can be reused, we will also refactor the epilogue part to make this feature less change for the code. Thanks for your contribution! We will refactor for you, no change request👍🏻 cc @zheanxu |
The swapAB variant “swap” the WGMMA tile usage, mapping the original problem’s M dimension onto WGMMA’s N dimension (which must be a multiple of 8). This enables smaller BLOCK_M (32). The performance advantage primarily comes from finer tiling granularity and better resource utilization. |
Thanks~ Looking forward to the release of the new version. |
SwapAB: Significantly improve the performance for M%64<32
Description
How to use
Improvements (H20)
Aligned M, desired state: masked_m[j] = int(expected_m_per_group * random.uniform(1, 1))
Other case (original test): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
TODO