Changes from all commits
23 commits
8bb5255
Refactor quant group size to be configurable for M/N/K, not just K
samremes Oct 13, 2025
98365f5
add some asserts for configurations not implemented
samremes Oct 13, 2025
f6b07dc
start setting of group size for N dimension
samremes Oct 13, 2025
22362f2
enable 2d for reference quant gemm
samremes Oct 14, 2025
9988a46
WIP: trying to figure out tile dstr and/or indexing for scale matrix
samremes Oct 16, 2025
36b88c6
WIP
samremes Oct 20, 2025
bb52cd9
Fix handling of n dim blocks in tile windows etc
samremes Oct 21, 2025
f179a8a
remove commented code and enable all tests again
samremes Oct 22, 2025
d100ab6
fix formatting
samremes Oct 22, 2025
37738e4
Add more specialized tile distributions
samremes Oct 27, 2025
98deefa
Enable NWarps replication for bquant tile dstr
samremes Oct 27, 2025
2d86cd0
fix formatting
samremes Oct 27, 2025
470d6e4
Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2…
samremes Oct 27, 2025
1f13003
fix format
samremes Oct 27, 2025
a449728
Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2…
samremes Oct 28, 2025
e12ab56
Fix some issues from the merge
samremes Oct 28, 2025
7c93551
fix formatting
samremes Oct 28, 2025
e1475d4
one more fix to tile dstr, and revert debug initialization
samremes Oct 28, 2025
5e0a356
Remove commented code
samremes Oct 29, 2025
1290b1b
simplify conditions that are needed for tile distributions
samremes Oct 29, 2025
306e25a
only enable the working group sizes in tests
samremes Oct 29, 2025
68e41da
fix formatting
samremes Oct 30, 2025
bcccafe
Update tile distribution for 2D bquant
CongMa13 Oct 31, 2025
30 changes: 16 additions & 14 deletions example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
@@ -17,7 +17,7 @@ template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
-uint32_t QuantGroupSize,
+typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
@@ -229,7 +229,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str

template <typename GemmConfig,
typename TypeConfig,
-uint32_t QuantGroupSize,
+typename QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -279,6 +279,8 @@ int run_gemm_example(int argc, char* argv[])

std::string quant_mode = arg_parser.get_str("quant_mode");

+using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
Contributor:
Could we make the quant group size an interface? Currently we have to specify the quant dimension sizes manually.

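For reference, a minimal editor's sketch of what such a compile-time group-shape wrapper could look like; the kM/kN/kK member names mirror the accesses in the diffs below, while the real ck_tile::QuantGroupShape (which takes a ck_tile::sequence) may be defined differently:

// Editor's sketch, not the ck_tile API: a hypothetical stand-in for
// ck_tile::QuantGroupShape exposing per-dimension group sizes.
template <int M_, int N_, int K_>
struct QuantGroupShapeSketch
{
    static constexpr int kM = M_; // rows of A/C sharing one scale (1 = per-row)
    static constexpr int kN = N_; // columns of B/C sharing one scale
    static constexpr int kK = K_; // contiguous K elements sharing one scale
};
// Mirrors the 1x1x128 choice above: per-row scale groups of 128 along K.
using QuantGroupSizeSketch = QuantGroupShapeSketch<1, 1, 128>;
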
if(data_type == "fp8")
{
using TypeConfig =
@@ -288,31 +290,31 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
@@ -331,31 +333,31 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
@@ -376,7 +378,7 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
@@ -397,7 +399,7 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
@@ -418,7 +420,7 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
@@ -439,7 +441,7 @@ int run_gemm_example(int argc, char* argv[])
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
-128,
+QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
16 changes: 8 additions & 8 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -14,7 +14,7 @@ template <typename GemmConfig,
typename BLayout,
typename BQLayout,
typename CLayout,
-uint32_t QuantGroupSize,
+typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
@@ -113,7 +113,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,

template <typename GemmConfig,
typename TypeConfig,
-uint32_t QuantGroupSize,
+typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename ALayout,
typename AQLayout,
@@ -146,7 +146,7 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::BQuantGrouped)
{
-if(K % QuantGroupSize != 0)
+if(K % QuantGroupSize::kK != 0)
{
throw std::runtime_error(
"K must be aligned with QuantGroupSize for AQuantGrouped/BQuantGrouped mode");
@@ -155,13 +155,13 @@
ck_tile::index_t AQK, BQK;
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
-AQK = K / QuantGroupSize; // Group quantization: AQK = K / GroupSize
-BQK = 0; // No B quantization
+AQK = K / QuantGroupSize::kK; // Group quantization: AQK = K / GroupSize
+BQK = 0; // No B quantization
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
-AQK = 0; // No A quantization
-BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
+AQK = 0; // No A quantization
+BQK = K / QuantGroupSize::kK; // Group quantization: BQK = K / GroupSize
}
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
QuantMode == ck_tile::QuantType::TensorQuant)
@@ -357,7 +357,7 @@ int run_gemm_example_with_layouts(int argc,
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
-ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize);
+ck_tile::shuffle_aq(aq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK);
aq_dev_buf_ptr->ToDevice(aq_shuffle_host.data());
}
else
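As a quick sanity check of the scale extents computed above, a small editor's sketch with assumed sample sizes (K = 512, group shape {1, 1, 128}; values are illustrative only):

// Editor's sketch with assumed numbers, mirroring the AQK/BQK logic above.
#include <cassert>
int main()
{
    constexpr int K = 512, kK = 128;
    static_assert(K % kK == 0, "mirrors the runtime alignment check above");
    constexpr int AQK = K / kK; // aquant: 4 scale entries per row of A
    constexpr int BQK = K / kK; // bquant: 4 scale entries per column of B
    assert(AQK == 4 && BQK == 4);
}
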
9 changes: 4 additions & 5 deletions include/ck_tile/host/reference/reference_gemm.hpp
@@ -16,7 +16,7 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
-uint32_t QuantGroupSize,
+typename QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
@@ -80,12 +80,11 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
v_block_acc += v_a * v_b;

// Apply group dequant scale
-if((k + 1) % QuantGroupSize == 0)
+if((k + 1) % QuantGroupSize::kK == 0)
{
float scale = 0.f;
-index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
-index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
+index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
+index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
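To make the new 2D indexing concrete, an editor's sketch of the scale lookup for BQuantGrouped; the row-major [K/kK x N/kN] scale layout and float scales are assumptions for illustration, not the ck_tile API:

// Editor's sketch: which scale dequantizes B(k, n) under group shape {1, kN, kK}.
#include <cstdint>
#include <vector>

float bquant_scale(const std::vector<float>& q,   // row-major [K/kK][N/kN]
                   std::int64_t k, std::int64_t n, // element coordinates in B
                   std::int64_t N, std::int64_t kN, std::int64_t kK)
{
    const std::int64_t scale_cols = N / kN;        // one scale column per N-group
    return q[(k / kK) * scale_cols + (n / kN)];    // q(outer_dim, inner_dim)
}
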
@@ -10,7 +10,7 @@ namespace ck_tile {

// A is block window on shared memory
// BQ (scale tensor) is block distributed tensor.
-// Consecutive kQuantGroupSize elements of B are quantized with a separate scale.
+// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
// B is block window on block distributed tensor.
// C is block distributed tensor
template <typename Problem_, typename BlockPolicy_>
@@ -24,6 +24,10 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
+using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
+
+static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
+static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
Copilot AI (Oct 28, 2025):
The second assertion prevents N-axis blocking in the preshuffle kernel, which conflicts with the PR's goal of supporting 2D block scales. It should be relaxed, or the preshuffle kernel should be updated to support N-axis blocking.

Suggested change:
-static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
+// static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");


static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -47,8 +51,7 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t KPerBlock = BlockGemmShape::kK;

-static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
-static constexpr index_t kBlockSize = Problem::kBlockSize;
+static constexpr index_t kBlockSize = Problem::kBlockSize;

static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp =
@@ -58,13 +61,12 @@
static constexpr auto MIter_2nd_last =
(MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;

-static constexpr index_t KPerBlockBQ = KPerBlock / kQuantGroupSize;
+static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;

static constexpr index_t QScalesPerBlockRow =
-(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
+integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
-(WG::kK + kQuantGroupSize - 1) / kQuantGroupSize;
+integer_divide_ceil(WG::kK, QuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
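The refactor also swaps the open-coded ceiling division for ck_tile's integer_divide_ceil helper; a minimal editor's equivalent, assuming positive operands:

// Editor's sketch of the ceil-division identity behind integer_divide_ceil.
constexpr int ceil_div_sketch(int a, int b) { return (a + b - 1) / b; }
static_assert(ceil_div_sketch(256, 128) == 2); // e.g. QScalesPerBlockRow
static_assert(ceil_div_sketch(32, 128) == 1);  // e.g. QScalesPerWarpGemmRow
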
@@ -46,7 +46,7 @@ struct BlockGemmAQuantBase

// A is block window on shared memory
// AQ (scale tensor) is block distributed tensor.
-// Consecutive kQuantGroupSize elements of A are quantized with a separate scale.
+// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_,
@@ -66,16 +66,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
+using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;

-static constexpr index_t kQuantGroupSize = Problem::kQuantGroupSize;
-static constexpr index_t kBlockSize = Problem::kBlockSize;
-static constexpr auto Scheduler = Problem::Scheduler;
+static constexpr index_t kBlockSize = Problem::kBlockSize;
+static constexpr auto Scheduler = Problem::Scheduler;

// Threadblock GEMM tile size
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
-static constexpr index_t AQPerBlock = KPerBlock / kQuantGroupSize;
+static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;

static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
@@ -101,20 +101,20 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

static constexpr index_t QScalesPerBlockRow =
-(KPerBlock + kQuantGroupSize - 1) / kQuantGroupSize;
+integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
static constexpr index_t QScalesPerWarpGemmRow =
-(WarpGemm::kK + kQuantGroupSize - 1) / kQuantGroupSize;
+integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);

static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;

-static_assert(kQuantGroupSize % WarpGemm::kK == 0,
-              "Error! WarpGemm::kK should be a multiple of kQuantGroupSize");
+static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
+              "Error! QuantGroupSize::kK should be a multiple of WarpGemm::kK");
static_assert(QScalesPerWarpGemmRow == 1,
"Error! kQuantGroupSize shouldn't be smaller than WarpGemm::kK");
"Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
"Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");

-static_assert(KPerBlock / kQuantGroupSize > 0,
+static_assert(KPerBlock / QuantGroupSize::kK > 0,
"Error! Each row of blockgemm should have a separate scale");

static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
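Putting the new constraints together, an editor's sketch of the scale/iteration bookkeeping with assumed tile sizes (KPerBlock = 256, WarpGemm::kK = 32, QuantGroupSize::kK = 128; illustrative only):

// Editor's sketch: the quantities the static_asserts above relate.
constexpr int KPerBlock = 256, WarpGemmK = 32, GroupK = 128;          // assumed
constexpr int KIterPerWarp = KPerBlock / WarpGemmK;                   // 8
constexpr int QScalesPerBlockRow = (KPerBlock + GroupK - 1) / GroupK; // 2
constexpr int KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;     // 4
static_assert(GroupK % WarpGemmK == 0);        // group spans whole K warp-tiles
static_assert(KIterPerWarp % QScalesPerBlockRow == 0);
static_assert(KPerBlock / GroupK > 0);         // at least one scale per block row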