diff --git a/CHANGELOG.md b/CHANGELOG.md index 94a2b279bc..213631721f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added WMMA (gfx12) support for FMHA. * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE +* Added 2D block scale support for CK_TILE GEMM. ### Changed diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index edde59081c..b22596537f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -1,6 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// This example demonstrates 2D block scale quantization (N×K) for BQuant +// using a non-preshuffled configuration. +// NOTE: Once more 2D support is ready, all 2D quant types can be migrated to this example. +// It is currently kept separate to avoid overly verbose dispatching. 
+ #include #include #include @@ -17,7 +22,7 @@ template float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) @@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmTraits, ComputeDataType>; + // This example only supports BQuant (no AQuant) + // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseAQuantGemmPipelineAgBgCrMem>; // memory pipeline hardcoded - // for aquant + ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { @@ -266,6 +272,41 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a return 0; } +// Forward declaration for dispatch function +template