Skip to content

Commit 8956bbd

Browse files
committed
Initial gtests for Stream-K
These changes add an initial gtest suite for the CK Tile Stream-K kernel. Because of bugs in the StreamKTilePartitioner (to be fixed in a future PR), validation currently fails for certain cases, and the set of failing cases may differ between architectures. For now, therefore, only fully data-parallel cases are run; all other cases are skipped. A guard was added to Stream-K's IsSupportedArgument method so that callers are made aware of this constraint. Additionally, to make testing reproducible, MakeKernelArgs now accepts optional arguments for the number of CUs and the occupancy.
1 parent e041a08 commit 8956bbd

File tree

7 files changed

+493
-8
lines changed

7 files changed

+493
-8
lines changed

include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,17 @@ struct StreamKKernel
141141
return UniversalGemmKernel::BlockSize();
142142
}
143143

144-
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
144+
/// @brief Constructs kernel arguments for the Stream-K kernel.
145+
/// @param host_args Stream-K host arguments.
146+
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
147+
/// The caller may select their own to assist with test reproducibility, etc.
148+
/// @param occupancy The maximum number of active blocks per CU for this kernel. The default is
149+
/// the value returned by Occupancy(). The caller may select their own to assist with test reproducibility, etc.
150+
/// @return The kernel arguments for Stream-K.
151+
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
152+
int num_cu = NumCU(),
153+
int occupancy = Occupancy())
145154
{
146-
uint32_t occupancy = static_cast<uint32_t>(Occupancy());
147-
uint32_t num_cu = static_cast<uint32_t>(NumCU());
148-
149155
return StreamKKernelArgs{{host_args.as_ptr,
150156
host_args.bs_ptr,
151157
host_args.ds_ptr,
@@ -166,8 +172,8 @@ struct StreamKKernel
166172
TilePartitioner{static_cast<uint32_t>(host_args.M),
167173
static_cast<uint32_t>(host_args.N),
168174
static_cast<uint32_t>(host_args.K),
169-
num_cu,
170-
occupancy,
175+
static_cast<uint32_t>(num_cu),
176+
static_cast<uint32_t>(occupancy),
171177
host_args.num_sk_blocks}};
172178
}
173179

@@ -212,9 +218,17 @@ struct StreamKKernel
212218
}
213219
}
214220

215-
CK_TILE_HOST static bool
216-
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs)
221+
/// @brief Checks whether the given Stream-K kernel arguments can be run by this kernel.
/// @param kargs Stream-K kernel arguments. Only `tile_partitioner.sk_num_blocks` is inspected
/// here; the arguments are then forwarded to the UniversalGemmKernel check.
/// @return false when `sk_num_blocks` is non-zero; otherwise the result of
/// UniversalGemmKernel::IsSupportedArgument.
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
217222
{
223+
// Any non-zero SK block count is rejected: only the data-parallel-only configuration is accepted.
if(kargs.tile_partitioner.sk_num_blocks != 0)
224+
{
225+
// The error is reported only when the CK_TILE_LOGGING environment switch is enabled.
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
226+
{
227+
CK_TILE_ERROR("CK Tile Stream-K currently only supports 0 SK blocks (i.e., "
228+
"data-parallel only).");
229+
}
230+
return false;
231+
}
218232
// Remaining validation is delegated to the underlying universal GEMM kernel.
return UniversalGemmKernel::IsSupportedArgument(kargs);
219233
}
220234

test/ck_tile/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle)
44
add_subdirectory(batched_gemm)
55
add_subdirectory(grouped_gemm)
66
add_subdirectory(gemm_multi_d)
7+
add_subdirectory(gemm_streamk)
78
add_subdirectory(data_type)
89
add_subdirectory(container)
910
add_subdirectory(elementwise)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# test_ck_tile_streamk is currently built only when GPU_TARGETS matches "gfx9".
# For any other target the executable is not added and a DEBUG-level message is emitted instead.
2+
if(GPU_TARGETS MATCHES "gfx9")
3+
add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp)
4+
else()
5+
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
6+
endif()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include "test_gemm_streamk_types.hpp"
5+
#include "test_gemm_streamk_util.hpp"
6+
#include "gtest/gtest.h"
7+
8+
// TEST_SUITE_NAME is the suite identifier consumed by the TYPED_TEST cases in
// test_gemm_streamk_cases.inc; it is defined only for the duration of that include.
#define TEST_SUITE_NAME TestCkTileStreamK
9+
10+
// Register the typed suite over every type combination listed in KernelTypesStreamK.
TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK);
11+
12+
// Pull in the shared test-case bodies, which expand against TEST_SUITE_NAME.
#include "test_gemm_streamk_cases.inc"
13+
14+
#undef TEST_SUITE_NAME
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#pragma once
5+
6+
#ifndef TEST_STREAM_K_CASES_INC
7+
#define TEST_STREAM_K_CASES_INC
8+
9+
// Runs a 256x256x256 problem with num_sk_blocks = 0, i.e. no Stream-K blocks are requested.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP)
10+
{
11+
12+
ck_tile::index_t M = 256;
13+
ck_tile::index_t N = 256;
14+
ck_tile::index_t K = 256;
15+
uint32_t num_sk_blocks = 0;
16+
17+
this->Run(M, N, K, num_sk_blocks);
18+
}
19+
20+
// Runs a 256x256x256 problem requesting 4 Stream-K blocks.
// NOTE(review): presumably Run skips cases with a non-zero SK block count for now — confirm in
// test_gemm_streamk_util.hpp.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4)
21+
{
22+
23+
ck_tile::index_t M = 256;
24+
ck_tile::index_t N = 256;
25+
ck_tile::index_t K = 256;
26+
uint32_t num_sk_blocks = 4;
27+
28+
this->Run(M, N, K, num_sk_blocks);
29+
}
30+
31+
// Runs a 256x256x256 problem requesting 8 Stream-K blocks.
// NOTE(review): presumably Run skips cases with a non-zero SK block count for now — confirm in
// test_gemm_streamk_util.hpp.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8)
32+
{
33+
34+
ck_tile::index_t M = 256;
35+
ck_tile::index_t N = 256;
36+
ck_tile::index_t K = 256;
37+
uint32_t num_sk_blocks = 8;
38+
39+
this->Run(M, N, K, num_sk_blocks);
40+
}
41+
42+
// Runs a 512x512x512 problem with num_sk_blocks = 0, i.e. no Stream-K blocks are requested.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP)
43+
{
44+
45+
ck_tile::index_t M = 512;
46+
ck_tile::index_t N = 512;
47+
ck_tile::index_t K = 512;
48+
uint32_t num_sk_blocks = 0;
49+
50+
this->Run(M, N, K, num_sk_blocks);
51+
}
52+
53+
// Runs a 512x512x512 problem requesting 16 Stream-K blocks.
// NOTE(review): presumably Run skips cases with a non-zero SK block count for now — confirm in
// test_gemm_streamk_util.hpp.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16)
54+
{
55+
56+
ck_tile::index_t M = 512;
57+
ck_tile::index_t N = 512;
58+
ck_tile::index_t K = 512;
59+
uint32_t num_sk_blocks = 16;
60+
61+
this->Run(M, N, K, num_sk_blocks);
62+
}
63+
64+
// Runs a 512x512x512 problem requesting 8 Stream-K blocks.
// NOTE(review): presumably Run skips cases with a non-zero SK block count for now — confirm in
// test_gemm_streamk_util.hpp.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8)
65+
{
66+
67+
ck_tile::index_t M = 512;
68+
ck_tile::index_t N = 512;
69+
ck_tile::index_t K = 512;
70+
uint32_t num_sk_blocks = 8;
71+
72+
this->Run(M, N, K, num_sk_blocks);
73+
}
74+
75+
// Runs a 3840x4096x4096 problem with num_sk_blocks = 0, i.e. no Stream-K blocks are requested.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP)
76+
{
77+
78+
ck_tile::index_t M = 3840;
79+
ck_tile::index_t N = 4096;
80+
ck_tile::index_t K = 4096;
81+
uint32_t num_sk_blocks = 0;
82+
83+
this->Run(M, N, K, num_sk_blocks);
84+
}
85+
86+
// Runs a 3840x4096x4096 problem requesting 960 Stream-K blocks.
// NOTE(review): presumably Run skips cases with a non-zero SK block count for now — confirm in
// test_gemm_streamk_util.hpp.
TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks960)
87+
{
88+
89+
ck_tile::index_t M = 3840;
90+
ck_tile::index_t N = 4096;
91+
ck_tile::index_t K = 4096;
92+
uint32_t num_sk_blocks = 960;
93+
94+
this->Run(M, N, K, num_sk_blocks);
95+
}
96+
97+
// Negative test: Run is expected to throw std::runtime_error when the Reduction strategy is
// requested (3840x4096x4096 problem, 64 Stream-K blocks).
TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction)
98+
{
99+
100+
ck_tile::index_t M = 3840;
101+
ck_tile::index_t N = 4096;
102+
ck_tile::index_t K = 4096;
103+
uint32_t num_sk_blocks = 64;
104+
105+
EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction),
106+
std::runtime_error);
107+
}
108+
109+
#endif
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include <tuple>
5+
#include <type_traits>
6+
7+
#include "gtest/gtest.h"
8+
9+
#include "ck_tile/host.hpp"
10+
11+
using F16 = ck_tile::half_t;
12+
using F32 = float;
13+
using BF16 = ck_tile::bf16_t;
14+
15+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
16+
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
17+
18+
// clang-format off
19+
// Type combinations passed to TYPED_TEST_SUITE for the Stream-K tests. Each tuple follows the
// column order given in the header comment inside the list.
using KernelTypesStreamK = ::testing::Types<
20+
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType
21+
std::tuple< Row, Col, Row, F16, F16, F32, F16>,
22+
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16>
23+
>;
24+
25+
// clang-format on

0 commit comments

Comments
 (0)