Skip to content

Commit c6010f2

Browse files
samremesmsaffari-amdspolifroni-amd
authored
[CK_TILE] Row/Col quant gemm (#2729)
* Add cshuffle epilogue test * add the poc implementation to the epilogue and tests * refactor cshuffle epilogue * WIP: adding tensor/tile usage to scale_tile * fix usage of tile_elementwise_inout * add gemm_quant_kernel for generalizing gemm quant kernel * Add problem specific to different quants, add QuantType to Traits * Add quant_type to quant_kernel template parameters * Create aq/bq_block_windows and views depending on QuantType * Use tile windows as inputs in cshuffle epilogue * Fix some issues in epilogue * initial new example code for new general gemm quant kernel test * Fix issues in kernel * Add verification check for rowcol Quantmode * use AccDataType instead of AQ in pipeline * fix aquant preshuffle * fix formatting * some cleanup * remove gemm_aquant_basic.cpp * remove gemm_aquant_kernel.hpp * fix tests for the renamed quant kernel * fix formatting * clean example files * fix some merge conflicts * fix preshufflequant rename issue * fix some templates after merging with develop * fix test preshuffle parameter * fix formatting * Unify bquant kernel to the common quant kernel * remove bquant kernel also from common header * fix formatting * clean up commented code * fix formatting config hpp * fix merge mistake * Non-const for movable windows * fix formatting * Fix grammar in README Co-authored-by: spolifroni-amd <[email protected]> * Remove #include<bit> and clean up example * fix strides * Add some descriptions for move_windows --------- Co-authored-by: Mohsen Saffari <[email protected]> Co-authored-by: spolifroni-amd <[email protected]>
1 parent 7330ec3 commit c6010f2

23 files changed

+1838
-1332
lines changed

example/ck_tile/38_block_scale_gemm/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ endif()
66
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
77

88
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
9-
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
10-
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
9+
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
10+
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
11+
12+
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
13+
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
14+
1115
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
1216
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
1317
else()
Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
# GEMM Matrix Multiplication
1+
# Quant GEMM Matrix Multiplication
22

3-
This folder contains example for Block Scale GEMM using ck_tile tile-programming implementation.
3+
This folder contains examples of quant GEMMs using the ck_tile tile-programming implementation.
4+
5+
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
6+
- Row and Column-wise scaled: scaling implemented in Epilogue
47

58
## build
69
```
710
# in the root of ck_tile
811
mkdir build && cd build
9-
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
12+
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
1013
../script/cmake-ck-dev.sh ../ <arch>
11-
# The aquant pipeline method on the gemm calculation
12-
make tile_example_gemm_aquant_basic -j
14+
# Compile the quant kernels
15+
make tile_example_gemm_quant_basic -j
1316
make tile_example_gemm_bquant_basic -j
1417
```
15-
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`
18+
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
1619

1720
## example
1821
```
@@ -22,15 +25,16 @@ args:
2225
-n n dimension (default:2048)
2326
-k k dimension (default:64)
2427
-a_layout Tensor A data layout (default: R)
25-
-b_layout Tensor B data layout (default: R)
28+
-b_layout Tensor B data layout (default: C)
2629
-c_layout Tensor C data layout (default: R)
2730
-stride_a Tensor A stride (default:0)
2831
-stride_b Tensor B stride (default:0)
2932
-stride_c Tensor C stride (default:0)
30-
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
33+
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
3134
-e Absolute error tolerance (default:1e-5)
32-
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
35+
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
3336
-warmup number of iterations before benchmark the kernel (default:10)
3437
-repeat number of iterations to benchmark the kernel (default:100)
3538
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
39+
-quant_mode Which quant method to use (aquant, rowcol)
3640
```

example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp

Lines changed: 0 additions & 226 deletions
This file was deleted.

example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ template <typename GemmConfig,
2121
typename BLayout,
2222
typename CLayout,
2323
uint32_t QuantGroupSize>
24-
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
24+
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
2525
{
2626
constexpr bool kPadM = false;
2727
constexpr bool kPadN = false;
@@ -50,13 +50,14 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
5050

5151
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
5252

53-
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
54-
kPadN,
55-
kPadK,
56-
GemmConfig::PreshuffleQuant,
57-
ALayout,
58-
BLayout,
59-
CLayout>;
53+
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
54+
kPadN,
55+
kPadK,
56+
GemmConfig::PreshuffleQuant,
57+
ALayout,
58+
BLayout,
59+
CLayout,
60+
ck_tile::QuantType::AQuantGrouped>;
6061

6162
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
6263
BDataType,
@@ -109,8 +110,10 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
109110
K_Warp_Tile,
110111
transposed_warp_gemm,
111112
ck_tile::memory_operation_enum::set>>;
112-
using Kernel =
113-
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
113+
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
114+
CodegenGemmPipeline,
115+
GemmEpilogue,
116+
ck_tile::QuantType::AQuantGrouped>;
114117

115118
auto kargs = Kernel::MakeKernelArgs(args);
116119

example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ template <typename GemmConfig,
2323
typename BLayout,
2424
typename CLayout,
2525
uint32_t QuantGroupSize>
26-
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
26+
float gemm_calc_bquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
2727
{
2828
constexpr bool kPadM = false;
2929
constexpr bool kPadN = false;
@@ -50,13 +50,14 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
5050

5151
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
5252

53-
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
54-
kPadN,
55-
kPadK,
56-
GemmConfig::PreshuffleQuant,
57-
ALayout,
58-
BLayout,
59-
CLayout>;
53+
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
54+
kPadN,
55+
kPadK,
56+
GemmConfig::PreshuffleQuant,
57+
ALayout,
58+
BLayout,
59+
CLayout,
60+
ck_tile::QuantType::BQuantGrouped>;
6061

6162
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
6263
BDataType,
@@ -108,8 +109,10 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
108109
K_Warp_Tile,
109110
transposed_warp_gemm,
110111
ck_tile::memory_operation_enum::set>>;
111-
using Kernel =
112-
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
112+
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
113+
CodegenGemmPipeline,
114+
GemmEpilogue,
115+
ck_tile::QuantType::BQuantGrouped>;
113116

114117
auto kargs = Kernel::MakeKernelArgs(args);
115118

0 commit comments

Comments
 (0)