[CK_TILE] B matrix 2D block scale gemm #3074
base: develop
Great PR!
| static constexpr index_t KPerBlock = BlockGemmShape::kK; | ||
| static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN; | ||
| static constexpr index_t BQPerBlock = KPerBlock / QuantGroupSize::kK; |
The name should be KQPerBlock in that case.
| static_assert(KPerBlock % QuantGroupSize == 0, | ||
| static_assert(KPerBlock % QuantGroupSize::kK == 0, | ||
| "KPerBlock must be a multiple of QuantGroupSize"); |
static_assert should be "KPerBlock must be a multiple of QuantGroupSize in K dim."
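For reference, a minimal sketch of how the per-block scale counts and the K-dim check fit together. `ScaleCounts` and its template parameters are hypothetical names used only for illustration (the PR derives these values from BlockGemmShape and QuantGroupSize), and `KQPerBlock` follows the naming suggested in the review rather than the name in the diff.

```cpp
// Hypothetical helper, not part of CK Tile: derives the number of scale
// entries per block tile from the block shape and the quant group shape.
template <int NPerBlock, int KPerBlock, int GroupN, int GroupK>
struct ScaleCounts
{
    // Same check as in the diff, with the message naming the dimension.
    static_assert(KPerBlock % GroupK == 0,
                  "KPerBlock must be a multiple of QuantGroupSize in K dim.");

    // Scale entries along N and K covered by one block tile (integer division).
    static constexpr int NQPerBlock = NPerBlock / GroupN;
    static constexpr int KQPerBlock = KPerBlock / GroupK;
};

// Example values are assumptions chosen for illustration only.
static_assert(ScaleCounts<128, 256, 64, 128>::NQPerBlock == 2, "");
static_assert(ScaleCounts<128, 256, 64, 128>::KQPerBlock == 2, "");
```

The sketch only asserts divisibility along K, as the diff does; it makes no claim about how group sizes larger than the block tile along N are handled.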
| static constexpr index_t XR = 2; | ||
| CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() | ||
| { | ||
| if constexpr(YPerQ == 1) |
YPerQ == 1 is the 1D blockscale case.
| { | ||
| // YPerQ == 1 implementation - each row of B has independent scale | ||
| constexpr index_t X = XPerTile; | ||
| constexpr index_t XR = 2; |
XR could be set to 1 in that case.
| index_t XPerTile, | ||
| index_t YPerQ, | ||
| index_t VecSize> | ||
| struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern |
This function looks good to me. @CongMa13 Could you also take a look at it and help me confirm?
| index_t KPerBlockAQ, | ||
| index_t VecSize, | ||
| bool PreshuffleQuant> | ||
| struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern |
Do we also have a similar change for this function, as for Bq?
Not yet for aquant. I think that should go pretty similar to bquant. I did some initial refactoring to support the group sizes for all M/N/K but kernel implementation only for the bquant.
Please also address the merge conflict.
Pull Request Overview
This PR introduces 2D block scale support for the B matrix in GEMM operations, enabling quantization grouping on both the N and K axes. Previously, only 1D grouping (along the K axis) was supported for B matrix quantization.
- Refactored QuantGroupSize from a simple integer constant to a structured type containing M, N, and K dimensions
- Added multiple 2D block size configurations (8N, 16N, 64N, 128N) for testing B matrix quantization
- Updated tile distribution logic to handle different group sizes with specialized patterns
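A rough sketch of what a structured group-size type along these lines could look like. `GroupShape`, `Group1D`, `Group2D64N` and `num_b_scales` are hypothetical stand-ins, not the CK Tile definitions; the (M, N, K) ordering follows the `sequence<1, 1, 128>` used in the example code, and the 64 x 128 shape is an assumed value.

```cpp
#include <cstddef>

// Hypothetical stand-in for a structured quant group size with M, N, K extents.
template <std::size_t M, std::size_t N, std::size_t K>
struct GroupShape
{
    static constexpr std::size_t kM = M;
    static constexpr std::size_t kN = N;
    static constexpr std::size_t kK = K;
};

// 1D grouping: one scale per 128 elements along K.
using Group1D = GroupShape<1, 1, 128>;
// 2D grouping (assumed shape): one scale per 64 x 128 (N x K) block of B.
using Group2D64N = GroupShape<1, 64, 128>;

// Number of scale entries for a B matrix with extents n and k,
// assuming both extents are multiples of the group extents.
template <typename G>
constexpr std::size_t num_b_scales(std::size_t n, std::size_t k)
{
    return (n / G::kN) * (k / G::kK);
}

static_assert(num_b_scales<Group1D>(256, 512) == 256 * 4, "");
static_assert(num_b_scales<Group2D64N>(256, 512) == 4 * 4, "");
```

With 2D grouping the scale tensor shrinks along N as well as K, which is why the tile distribution for the scales has to change with the group size.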
Reviewed Changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp | Adds 2D block size test configurations and updates GroupSize definitions |
| test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp | Updates test fixtures to use structured QuantGroupSize type and handle 2D blocks |
| test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp | Changes QuantGroupSize from integer constant to type alias |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp | Introduces QuantGroupShape struct and implements conditional tile distribution patterns |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp | Updates pipeline problem definition to use QuantGroupSize type |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp | Updates BQuant pipeline to calculate NPerBlockBQ and use structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | Updates weight preshuffle pipeline for new QuantGroupSize structure |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp | Adds NPerBlockBQ calculation in policy |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp | Updates policy to use QuantGroupSize::kN and ::kK |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp | Adds NPerBlockBQ and validates block dimensions |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp | Updates AQuant pipeline for structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp | Updates AQuant policy to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp | Updates memory pipeline for structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp | Updates base implementation to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp | Updates kernel to handle 2D block scales with proper indexing |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp | Implements conditional scale indexing based on NQPerBlock |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp | Updates AQuant block to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp | Updates preshuffle block for structured QuantGroupSize |
| include/ck_tile/host/reference/reference_gemm.hpp | Updates reference implementation to handle 2D quantization |
| example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc | Updates example to use QuantGroupSize type |
| example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp | Updates example instantiations with structured QuantGroupSize |
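To illustrate what 2D quantization means for a host reference, here is a minimal sketch of a GEMM with block scales on B along both N and K. The function name, the float-only types, and the scale layout (a (K / group_k) x (N / group_n) row-major array) are assumptions for illustration; this is not the code in reference_gemm.hpp.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C (M x N, row-major) = A (M x K, row-major) * dequantized B.
// B is stored column-major: element (k, n) lives at b[n * K + k].
// bq holds one scale per (group_k x group_n) block of B (assumed layout:
// row-major with K / group_k rows and N / group_n columns).
std::vector<float> reference_gemm_bquant_2d(const std::vector<float>& a,
                                            const std::vector<float>& b,
                                            const std::vector<float>& bq,
                                            std::size_t M, std::size_t N, std::size_t K,
                                            std::size_t group_n, std::size_t group_k)
{
    assert(N % group_n == 0 && K % group_k == 0);
    const std::size_t nq = N / group_n; // scale columns
    assert(a.size() == M * K && b.size() == K * N);
    assert(bq.size() == (K / group_k) * nq);

    std::vector<float> c(M * N, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
            {
                // Every element in the same (group_k x group_n) block shares a scale.
                const float scale = bq[(k / group_k) * nq + (n / group_n)];
                acc += a[m * K + k] * (b[n * K + k] * scale);
            }
            c[m * N + n] = acc;
        }
    }
    return c;
}
```

Setting group_n to 1 gives the previous 1D behavior, where each column of scales applies to a single index along N.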
| K, // M, N, K | ||
| 0, // QK_A (not used for BQuant) | ||
| BQK, // QK_B | ||
| BQK, // QK_B - TODO: we can remove BQK and BQN from args later? |
Copilot AI, Oct 28, 2025
TODO comment suggests incomplete refactoring. Consider creating a tracked issue for this technical debt or addressing it in this PR if feasible.
| BQK, // QK_B - TODO: we can remove BQK and BQN from args later? | |
| BQK, // QK_B - FIXME: See issue #1234 - refactor to remove BQK and BQN from args if not needed |
| CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() | ||
| { |
Copilot AI, Oct 28, 2025
This complex function with multiple conditional branches handling different YPerQ values would benefit from a documentation comment explaining the different cases and when each branch is selected.
| CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() | |
| { | |
| /** | |
| * Returns a static tile distribution encoding for quantized GEMM, handling different cases | |
| * based on YPerQ (number of rows per quantization scale) and YPerTile (tile size along Y). | |
| * | |
| * Branches: | |
| * 1. YPerQ == 1: | |
| * - Each row of B has an independent scale. | |
| * - Distribution splits Y into (NIterPerWarp, NWarps, WarpGemm::kN). | |
| * - Used when the finest granularity of quantization is required. | |
| * | |
| * 2. YPerTile >= NIterPerWarp * NWarps: | |
| * - All warps in the block use the same scale. | |
| * - Distribution replicates the scale across warps. | |
| * - Used when quantization scale covers the entire block tile along Y. | |
| * | |
| * 3. YPerTile >= NIterPerWarp: | |
| * - All NWarps have the same scale, replicated across warps. | |
| * - Used when quantization scale covers multiple iterations per warp. | |
| * | |
| * 4. Otherwise: | |
| * - Larger NQ block size, multiple iters/warps use the same scale, replicated to all threads. | |
| * - Used when quantization scale is coarser than the block tile. | |
| * | |
| * This function ensures the correct distribution of quantization scales for each tile, | |
| * optimizing memory access and computation based on the quantization granularity. | |
| */ | |
| CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() |
This is actually nice documentation! Can we add it?
include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
Outdated
| using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>; | ||
| static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); | ||
| static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); |
Copilot AI, Oct 28, 2025
The second assertion prevents N-axis blocking in the preshuffle kernel, which conflicts with the PR's goal of supporting 2D block scales. This should be relaxed or the preshuffle kernel should be updated to support N-axis blocking.
| static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); | |
| // static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!"); |
Hi @samremes, could you please resolve the merge conflicts?
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, | ||
| // 2d cases with grouping also on the n axis | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D> |
It is awesome to have these unit tests 👍
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, | ||
| // 2d cases with grouping also on the n axis | ||
| // FIXME: why is group size 8 not working? |
group size cannot be smaller than GemmWarp::kN since warp tile is the smallest compute tile.
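A small sketch of this constraint as a compile-time check. The helper name and the warp tile width of 16 are assumptions for illustration, and treating a group size of 1 as a separate, allowed case is an inference from the "each row of B has independent scale" branch in the diff, not something stated in the PR.

```cpp
// Hypothetical check: the N group either is the per-row case (group_n == 1)
// or must cover at least one full warp tile along N.
constexpr bool n_group_covers_warp_tile(int group_n, int warp_n)
{
    return group_n == 1 || group_n >= warp_n;
}

// warp_n = 16 is an assumed warp tile width, not taken from the PR.
static_assert(!n_group_covers_warp_tile(8, 16), "group size 8 is below the warp tile");
static_assert(n_group_covers_warp_tile(16, 16), "");
static_assert(n_group_covers_warp_tile(1, 16), "");
```

This is at most a necessary condition: a group size of 32 is also reported as failing in this thread, so the tile distribution imposes further requirements.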
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>, | ||
| std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, |
Tried to add group size 32 and the test case failed. Could you please try it?
I didn't get it to work either. I think the current tile distributions won't be able to handle it. I changed the conditions to be more specific that the split into NWarps/NIterPerWarp has to be exact.
| std::string quant_mode = arg_parser.get_str("quant_mode"); | ||
| using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>; |
Could we make the Quant Group Size as an interface? Currently, we need to manually put the quant dim size.
@CongMa13 Please try the tile distribution solution we discussed today and see the perf difference.
@ThomasNing @CongMa13 Did you have some ideas for the tile distribution? I think the current versions require that it exactly splits with NWarps and/or NIterPerWarp.
Proposed changes
Introduces 2d block scale support for B matrix (grouping both on N and K axes). The tile distribution for the scale matrix has different options depending on the group size.
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered