@samremes (Contributor) commented Oct 22, 2025

Proposed changes

Introduces 2D block-scale support for the B matrix (quantization groups along both the N and K axes). The tile distribution for the scale matrix takes one of several forms depending on the group size.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@ThomasNing (Contributor) left a comment

Great PR!

static constexpr index_t KPerBlock = BlockGemmShape::kK;

static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
static constexpr index_t BQPerBlock = KPerBlock / QuantGroupSize::kK;

Contributor

The name should be KQPerBlock in that case.
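
The naming point is easier to see with the two quantities side by side. The following is a minimal sketch, not the PR's code: `ExampleGroupShape`, `ScalesPerBlock`, and the tile sizes are hypothetical stand-ins, and `KQPerBlock` is the name proposed in this comment.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical stand-in for the structured QuantGroupSize type.
template <index_t M, index_t N, index_t K>
struct ExampleGroupShape
{
    static constexpr index_t kM = M;
    static constexpr index_t kN = N;
    static constexpr index_t kK = K;
};

// Number of scale entries that one block tile covers along each axis.
template <index_t NPerBlock, index_t KPerBlock, typename QuantGroupSize>
struct ScalesPerBlock
{
    static constexpr index_t NQPerBlock = NPerBlock / QuantGroupSize::kN;
    static constexpr index_t KQPerBlock = KPerBlock / QuantGroupSize::kK;
};

// Illustrative numbers only: a 128 x 256 (N x K) block tile with 16 x 128 groups.
using Example = ScalesPerBlock<128, 256, ExampleGroupShape<1, 16, 128>>;
static_assert(Example::NQPerBlock == 8, "128 / 16 scale groups along N");
static_assert(Example::KQPerBlock == 2, "256 / 128 scale groups along K");
```

With both counts named after the axis they divide, the N and K quantities read symmetrically.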


- static_assert(KPerBlock % QuantGroupSize == 0,
+ static_assert(KPerBlock % QuantGroupSize::kK == 0,
  "KPerBlock must be a multiple of QuantGroupSize");

Contributor

The static_assert message should be "KPerBlock must be a multiple of QuantGroupSize in K dim."
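
A minimal sketch of what dimension-specific messages could look like once the group size has more than one axis. `ExampleGroupSize` and `CheckedBlockShape` are hypothetical names, the group values are illustrative, and the N-axis message is an assumed counterpart to the K-axis wording requested here.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Hypothetical stand-in for the structured QuantGroupSize type.
struct ExampleGroupSize
{
    static constexpr index_t kN = 16;
    static constexpr index_t kK = 128;
};

template <index_t NPerBlock, index_t KPerBlock, typename QuantGroupSize>
struct CheckedBlockShape
{
    // Assumed N-axis counterpart; the PR's actual wording may differ.
    static_assert(NPerBlock % QuantGroupSize::kN == 0,
                  "NPerBlock must be a multiple of QuantGroupSize in N dim.");
    // Message text as requested in this review comment.
    static_assert(KPerBlock % QuantGroupSize::kK == 0,
                  "KPerBlock must be a multiple of QuantGroupSize in K dim.");
    static constexpr bool valid = true;
};

// Compiles only if both divisibility checks hold.
static_assert(CheckedBlockShape<128, 256, ExampleGroupSize>::valid, "shape accepted");
```

Naming the axis in each message tells the user which tile dimension to adjust when a check fails.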

static constexpr index_t XR = 2;
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
if constexpr(YPerQ == 1)

Contributor

YPerQ == 1 is the 1D block-scale case.

{
// YPerQ == 1 implementation - each row of B has independent scale
constexpr index_t X = XPerTile;
constexpr index_t XR = 2;

Contributor

XR could be set to 1 in that case.

index_t XPerTile,
index_t YPerQ,
index_t VecSize>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern

Contributor

This function looks good to me. @CongMa13 Could you also take a look at it and help me confirm?

index_t KPerBlockAQ,
index_t VecSize,
bool PreshuffleQuant>
struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern

Contributor

Do we also have a similar change here as for Bq?

Contributor Author

Not yet for aquant. I think that should go pretty similarly to bquant. I did some initial refactoring to support group sizes along all of M/N/K, but the kernel implementation is only there for bquant.

@ThomasNing (Contributor)

Please also address the merge conflict.

@aosewski aosewski requested a review from Copilot October 28, 2025 11:40
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces 2D block scale support for the B matrix in GEMM operations, enabling quantization grouping on both the N and K axes. Previously, only 1D grouping (along the K axis) was supported for B matrix quantization.

  • Refactored QuantGroupSize from a simple integer constant to a structured type containing M, N, and K dimensions
  • Added multiple 2D block size configurations (8N, 16N, 64N, 128N) for testing B matrix quantization
  • Updated tile distribution logic to handle different group sizes with specialized patterns
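
To make the N-and-K grouping concrete, here is a minimal host-side sketch of looking up the scale for one element of B. It is not the PR's `reference_gemm`: the function names are hypothetical, and the scale layout (K-group index fastest) is an assumption. B is taken as column-major, matching the `ColumnMajor` B layout used in the tests.

```cpp
#include <cstddef>
#include <vector>

// Flat index of the scale that covers element (k, n) of B.
// ASSUMPTION: scales are stored with the K-group index fastest; the PR's
// actual scale layout may differ.
constexpr std::size_t scale_index(std::size_t k,
                                  std::size_t n,
                                  std::size_t group_k,
                                  std::size_t group_n,
                                  std::size_t num_groups_k)
{
    return (n / group_n) * num_groups_k + (k / group_k);
}

// Dequantize a column-major K x N matrix B (K contiguous for a fixed n).
// K and N are assumed to be multiples of group_k and group_n.
inline std::vector<float> dequantize_b(const std::vector<float>& b_quantized,
                                       const std::vector<float>& scales,
                                       std::size_t K,
                                       std::size_t N,
                                       std::size_t group_k,
                                       std::size_t group_n)
{
    const std::size_t num_groups_k = K / group_k;
    std::vector<float> b(K * N);
    for(std::size_t n = 0; n < N; ++n)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            const std::size_t idx = n * K + k;
            b[idx] = b_quantized[idx] *
                     scales[scale_index(k, n, group_k, group_n, num_groups_k)];
        }
    }
    return b;
}
```

With `group_n == 1` every column of B gets its own row of scales, which is the earlier 1D behaviour; larger `group_n` values make neighbouring columns share a scale.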

Reviewed Changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp | Adds 2D block size test configurations and updates GroupSize definitions |
| test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp | Updates test fixtures to use structured QuantGroupSize type and handle 2D blocks |
| test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp | Changes QuantGroupSize from integer constant to type alias |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp | Introduces QuantGroupShape struct and implements conditional tile distribution patterns |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp | Updates pipeline problem definition to use QuantGroupSize type |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp | Updates BQuant pipeline to calculate NPerBlockBQ and use structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | Updates weight preshuffle pipeline for new QuantGroupSize structure |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp | Adds NPerBlockBQ calculation in policy |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp | Updates policy to use QuantGroupSize::kN and ::kK |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp | Adds NPerBlockBQ and validates block dimensions |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp | Updates AQuant pipeline for structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp | Updates AQuant policy to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp | Updates memory pipeline for structured QuantGroupSize |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp | Updates base implementation to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp | Updates kernel to handle 2D block scales with proper indexing |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp | Implements conditional scale indexing based on NQPerBlock |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp | Updates AQuant block to use QuantGroupSize::kK |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp | Updates preshuffle block for structured QuantGroupSize |
| include/ck_tile/host/reference/reference_gemm.hpp | Updates reference implementation to handle 2D quantization |
| example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc | Updates example to use QuantGroupSize type |
| example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp | Updates example instantiations with structured QuantGroupSize |

K, // M, N, K
0, // QK_A (not used for BQuant)
- BQK, // QK_B
+ BQK, // QK_B - TODO: we can remove BQK and BQN from args later?

Copilot AI commented on Oct 28, 2025

TODO comment suggests incomplete refactoring. Consider creating a tracked issue for this technical debt or addressing it in this PR if feasible.

Suggested change
- BQK, // QK_B - TODO: we can remove BQK and BQN from args later?
+ BQK, // QK_B - FIXME: See issue #1234 - refactor to remove BQK and BQN from args if not needed

Comment on lines +191 to +192
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{

Copilot AI commented on Oct 28, 2025

This complex function with multiple conditional branches handling different YPerQ values would benefit from a documentation comment explaining the different cases and when each branch is selected.

Suggested change
- CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
- {
+ /**
+  * Returns a static tile distribution encoding for quantized GEMM, handling different cases
+  * based on YPerQ (number of rows per quantization scale) and YPerTile (tile size along Y).
+  *
+  * Branches:
+  * 1. YPerQ == 1:
+  *    - Each row of B has an independent scale.
+  *    - Distribution splits Y into (NIterPerWarp, NWarps, WarpGemm::kN).
+  *    - Used when the finest granularity of quantization is required.
+  *
+  * 2. YPerTile >= NIterPerWarp * NWarps:
+  *    - All warps in the block use the same scale.
+  *    - Distribution replicates the scale across warps.
+  *    - Used when quantization scale covers the entire block tile along Y.
+  *
+  * 3. YPerTile >= NIterPerWarp:
+  *    - All NWarps have the same scale, replicated across warps.
+  *    - Used when quantization scale covers multiple iterations per warp.
+  *
+  * 4. Otherwise:
+  *    - Larger NQ block size, multiple iters/warps use the same scale, replicated to all threads.
+  *    - Used when quantization scale is coarser than the block tile.
+  *
+  * This function ensures the correct distribution of quantization scales for each tile,
+  * optimizing memory access and computation based on the quantization granularity.
+  */
+ CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()

Collaborator

This is actually nice documentation! Can we add it?

using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;

static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");

Copilot AI commented on Oct 28, 2025

The second assertion prevents N-axis blocking in the preshuffle kernel, which conflicts with the PR's goal of supporting 2D block scales. This should be relaxed or the preshuffle kernel should be updated to support N-axis blocking.

Suggested change
- static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
+ // static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");

@illsilin (Collaborator)

Hi @samremes, could you please resolve the merge conflicts?

Comment on lines 62 to 73
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,

std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,

// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>

Collaborator

It is awesome to have these unit tests 👍

std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,

// 2d cases with grouping also on the n axis
// FIXME: why is group size 8 not working?

Collaborator

The group size cannot be smaller than GemmWarp::kN, since the warp tile is the smallest compute tile.
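
A minimal sketch of this rule as a compile-time check. The function name and the warp tile width of 16 are assumptions for illustration, and `group_n == 1` is exempted because the snippets above give it its own per-row-scale branch.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Rule as stated in this comment: an N-axis group must span at least one
// warp tile. group_n == 1 is the separate 1D (per-row scale) case.
// Hypothetical helper, not part of the PR.
constexpr bool meets_warp_tile_minimum(index_t group_n, index_t warp_n)
{
    return group_n == 1 || group_n >= warp_n;
}

// ASSUMPTION: a warp tile that is 16 wide along N.
constexpr index_t kExampleWarpN = 16;

static_assert(meets_warp_tile_minimum(1, kExampleWarpN), "1D case is allowed");
static_assert(!meets_warp_tile_minimum(8, kExampleWarpN), "8 is below the warp tile");
static_assert(meets_warp_tile_minimum(16, kExampleWarpN), "16 matches the warp tile");
```

This check is necessary but not sufficient: group size 32 satisfies it, yet the discussion below reports that it still fails with the current tile distributions.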

std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D16N>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>,

Collaborator

I tried adding group size 32 and the test case failed. Could you please try it?

Contributor Author

I didn't get it to work either. I think the current tile distributions won't be able to handle it. I made the conditions more specific: the split into NWarps/NIterPerWarp has to be exact.


std::string quant_mode = arg_parser.get_str("quant_mode");

using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

Contributor

Could we expose the quant group size through an interface? Currently, we need to set the quant dim sizes manually.

@ThomasNing (Contributor)

@CongMa13 Please try the tile distribution solution we discussed today and see the perf difference.

@samremes (Contributor Author)

@ThomasNing @CongMa13 Did you have some ideas for the tile distribution? I think the current versions require that it exactly splits with NWarps and/or NIterPerWarp.
