Skip to content

Commit cafaeb6

Browse files
authored
Add instance traits for two more grouped forward convolutions (#3112)
1 parent 121bf0e commit cafaeb6

11 files changed

+1193
-13
lines changed

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 350 additions & 0 deletions
Large diffs are not rendered by default.

experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

Lines changed: 344 additions & 0 deletions
Large diffs are not rendered by default.

experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <ck/utility/data_type.hpp>
1616
#include <ck/utility/sequence.hpp>
1717
#include <ck/utility/blkgemmpipe_scheduler.hpp>
18+
#include <ck/utility/loop_scheduler.hpp>
1819
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
1920
#include <ck_tile/ops/common/tensor_layout.hpp>
2021
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
@@ -160,6 +161,17 @@ constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ve
160161
}
161162
}
162163

164+
// Convert LoopScheduler enum to string
165+
constexpr std::string_view loop_scheduler_name(ck::LoopScheduler sched)
166+
{
167+
using enum ck::LoopScheduler;
168+
switch(sched)
169+
{
170+
case Default: return "Default";
171+
case Interwave: return "Interwave";
172+
}
173+
}
174+
163175
// Convert std::array to string
164176
template <typename T, std::size_t N>
165177
inline std::string array_to_string(const std::array<T, N>& arr)

experimental/builder/test/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ add_ck_builder_test(test_inline_diff test_inline_diff.cpp)
2626

2727
# Testing the virtual GetInstanceString methods requires kernel compilation.
2828
add_ck_builder_test(test_get_instance_string
29-
test_get_instance_string.cpp)
29+
test_get_instance_string_fwd_grp_conv_v3.cpp
30+
test_get_instance_string_fwd_grp_conv.cpp
31+
test_get_instance_string_fwd_grp_conv_large_tensor.cpp)
3032

3133
# Testing the fwd convolution builder requires kernel compilation.
3234
# To enable parallel compilation, the individual tests are split into separate files.

experimental/builder/test/test_instance_traits.cpp renamed to experimental/builder/test/test_fwd_instance_traits.cpp

Lines changed: 234 additions & 10 deletions
Large diffs are not rendered by default.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
4+
#include <gtest/gtest.h>
5+
#include <ck_tile/builder/reflect/instance_traits.hpp>
6+
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
7+
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp>
8+
9+
// Test GetInstanceString through base class pointer for non-V3 variant
10+
TEST(GetInstanceString, ReturnsStringForFwdGrpConvInstance)
11+
{
12+
// Use the template helper to get a working instance configuration
13+
using InstanceTuple =
14+
ck::tensor_operation::device::instance::device_grouped_conv_fwd_xdl_f16_instances<
15+
2, // NDimSpatial
16+
ck::tensor_operation::device::instance::GNHWC, // ALayout
17+
ck::tensor_operation::device::instance::GKYXC, // BLayout
18+
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
19+
ck::tensor_operation::device::instance::GNHWK, // ELayout
20+
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
21+
22+
// Get the first instance from the tuple
23+
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
24+
25+
// Define the base class type using DeviceGroupedConvFwdMultipleABD
26+
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
27+
2, // NDimSpatial
28+
ck::tensor_operation::device::instance::GNHWC, // ALayout
29+
ck::tensor_operation::device::instance::GKYXC, // BLayout
30+
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
31+
ck::tensor_operation::device::instance::GNHWK, // ELayout
32+
ck::half_t, // ADataType
33+
ck::half_t, // BDataType
34+
ck::Tuple<>, // DsDataType
35+
ck::half_t, // EDataType
36+
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
37+
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
38+
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
39+
ck::half_t, // AComputeType
40+
ck::half_t>; // BComputeType
41+
42+
// Create an instance of the derived class
43+
DeviceInstance device_instance;
44+
45+
// Get a pointer to the base class
46+
BaseClass* base_ptr = &device_instance;
47+
48+
// Call GetInstanceString through the base class pointer
49+
std::string instance_str = base_ptr->GetInstanceString();
50+
51+
// Expected complete instance string based on the first instance from
52+
// device_grouped_conv_fwd_xdl_f16_instances
53+
std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
54+
"<2" // NDimSpatial
55+
",GNHWC" // ALayout
56+
",GKYXC" // BLayout
57+
",EmptyTuple" // DsLayout
58+
",GNHWK" // ELayout
59+
",fp16" // ADataType
60+
",fp16" // BDataType
61+
",fp32" // AccDataType
62+
",fp16" // CShuffleDataType
63+
",EmptyTuple" // DsDataType
64+
",fp16" // EDataType
65+
",PassThrough" // AElementwiseOperation
66+
",PassThrough" // BElementwiseOperation
67+
",PassThrough" // CDEElementwiseOperation
68+
",Default" // ConvForwardSpecialization
69+
",MNKPadding" // GemmSpec
70+
",1" // NumGemmKPrefetchStage
71+
",64" // BlockSize
72+
",64" // MPerBlock
73+
",64" // NPerBlock
74+
",32" // KPerBlock
75+
",8" // AK1
76+
",8" // BK1
77+
",32" // MPerXDL
78+
",32" // NPerXDL
79+
",2" // MXdlPerWave
80+
",2" // NXdlPerWave
81+
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
82+
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
83+
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
84+
",2" // ABlockTransferSrcVectorDim
85+
",1" // ABlockTransferSrcScalarPerVector
86+
",8" // ABlockTransferDstScalarPerVector_AK1
87+
",1" // ABlockLdsExtraM
88+
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
89+
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
90+
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
91+
",2" // BBlockTransferSrcVectorDim
92+
",1" // BBlockTransferSrcScalarPerVector
93+
",8" // BBlockTransferDstScalarPerVector_BK1
94+
",1" // BBlockLdsExtraN
95+
",1" // CShuffleMXdlPerWavePerShuffle
96+
",1" // CShuffleNXdlPerWavePerShuffle
97+
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
98+
",1" // CDEBlockTransferScalarPerVector_NPerBlock
99+
",fp16" // AComputeDataType
100+
",fp16" // BComputeDataType
101+
",Default" // LoopScheduler
102+
",1>"; // NumGroupsToMerge
103+
EXPECT_EQ(instance_str, expected_str);
104+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3+
4+
#include <gtest/gtest.h>
5+
#include <ck_tile/builder/reflect/instance_traits.hpp>
6+
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
7+
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp>
8+
9+
// Test GetInstanceString through base class pointer for large tensor variant
10+
TEST(GetInstanceString, ReturnsStringForFwdGrpConvLargeTensorInstance)
11+
{
12+
// Use the template helper to get a working instance configuration
13+
using InstanceTuple = ck::tensor_operation::device::instance::
14+
device_grouped_conv_fwd_xdl_large_tensor_f16_instances<
15+
2, // NDimSpatial
16+
ck::tensor_operation::device::instance::GNHWC, // ALayout
17+
ck::tensor_operation::device::instance::GKYXC, // BLayout
18+
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
19+
ck::tensor_operation::device::instance::GNHWK, // ELayout
20+
ck::tensor_operation::device::instance::ConvFwdDefault>; // ConvForwardSpecialization
21+
22+
// Get the first instance from the tuple
23+
using DeviceInstance = typename std::tuple_element<0, InstanceTuple>::type;
24+
25+
// Define the base class type using DeviceGroupedConvFwdMultipleABD
26+
using BaseClass = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
27+
2, // NDimSpatial
28+
ck::tensor_operation::device::instance::GNHWC, // ALayout
29+
ck::tensor_operation::device::instance::GKYXC, // BLayout
30+
ck::tensor_operation::device::instance::Empty_Tuple, // DsLayout
31+
ck::tensor_operation::device::instance::GNHWK, // ELayout
32+
ck::half_t, // ADataType
33+
ck::half_t, // BDataType
34+
ck::Tuple<>, // DsDataType
35+
ck::half_t, // EDataType
36+
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation
37+
ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation
38+
ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation
39+
ck::half_t, // AComputeType
40+
ck::half_t>; // BComputeType
41+
42+
// Create an instance of the derived class
43+
DeviceInstance device_instance;
44+
45+
// Get a pointer to the base class
46+
BaseClass* base_ptr = &device_instance;
47+
48+
// Call GetInstanceString through the base class pointer
49+
std::string instance_str = base_ptr->GetInstanceString();
50+
51+
// Expected complete instance string based on the first instance from
52+
// device_grouped_conv_fwd_xdl_large_tensor_f16_instances
53+
std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
54+
"<2" // NDimSpatial
55+
",GNHWC" // ALayout
56+
",GKYXC" // BLayout
57+
",EmptyTuple" // DsLayout
58+
",GNHWK" // ELayout
59+
",fp16" // ADataType
60+
",fp16" // BDataType
61+
",fp32" // AccDataType
62+
",fp16" // CShuffleDataType
63+
",EmptyTuple" // DsDataType
64+
",fp16" // EDataType
65+
",PassThrough" // AElementwiseOperation
66+
",PassThrough" // BElementwiseOperation
67+
",PassThrough" // CDEElementwiseOperation
68+
",Default" // ConvForwardSpecialization
69+
",MNKPadding" // GemmSpec
70+
",1" // NumGemmKPrefetchStage
71+
",64" // BlockSize
72+
",64" // MPerBlock
73+
",64" // NPerBlock
74+
",32" // KPerBlock
75+
",8" // AK1
76+
",8" // BK1
77+
",32" // MPerXDL
78+
",32" // NPerXDL
79+
",2" // MXdlPerWave
80+
",2" // NXdlPerWave
81+
",Seq(4,16,1)" // ABlockTransferThreadClusterLengths
82+
",Seq(1,0,2)" // ABlockTransferThreadClusterArrangeOrder
83+
",Seq(1,0,2)" // ABlockTransferSrcAccessOrder
84+
",2" // ABlockTransferSrcVectorDim
85+
",1" // ABlockTransferSrcScalarPerVector
86+
",8" // ABlockTransferDstScalarPerVector_AK1
87+
",1" // ABlockLdsExtraM
88+
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
89+
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
90+
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
91+
",2" // BBlockTransferSrcVectorDim
92+
",1" // BBlockTransferSrcScalarPerVector
93+
",8" // BBlockTransferDstScalarPerVector_BK1
94+
",1" // BBlockLdsExtraN
95+
",1" // CShuffleMXdlPerWavePerShuffle
96+
",1" // CShuffleNXdlPerWavePerShuffle
97+
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
98+
",1" // CDEBlockTransferScalarPerVector_NPerBlock
99+
",fp16" // AComputeDataType
100+
",fp16" // BComputeDataType
101+
",Default>"; // LoopScheduler
102+
EXPECT_EQ(instance_str, expected_str);
103+
}

experimental/builder/test/test_get_instance_string.cpp renamed to experimental/builder/test/test_get_instance_string_fwd_grp_conv_v3.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include <ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp>
77
#include <ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp>
88

9-
// Test GetInstanceString through base class pointer
10-
TEST(GetInstanceStringTest, GetInstanceStringThroughBaseClass)
9+
// Test GetInstanceString through base class pointer for V3 variant
10+
TEST(GetInstanceString, ReturnsStringForFwdGrpConvV3Instance)
1111
{
1212
// Use the template helper to get a working instance configuration
1313
using InstanceTuple =

experimental/builder/test/test_instance_traits_util.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
199199
ElementsAre("v1", "v2", "v3", "v4", "v5"));
200200
}
201201

202+
TEST(InstanceTraitsUtil, LoopSchedulerNameReturnsCorrectStrings)
203+
{
204+
using enum ck::LoopScheduler;
205+
EXPECT_THAT(std::vector<std::string_view> names = {loop_scheduler_name(Default),
206+
loop_scheduler_name(Interwave)},
207+
ElementsAre("Default", "Interwave"));
208+
}
209+
202210
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
203211
{
204212
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#include "ck/host_utility/device_prop.hpp"
2929
#include "ck/host_utility/kernel_launch.hpp"
3030
#include "ck/host_utility/io.hpp"
31+
#ifdef CK_EXPERIMENTAL_BUILDER
32+
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
33+
#endif
3134

3235
namespace ck {
3336
namespace tensor_operation {
@@ -2063,6 +2066,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
20632066
return str.str();
20642067
}
20652068

2069+
#ifdef CK_EXPERIMENTAL_BUILDER
2070+
std::string GetInstanceString() const override
2071+
{
2072+
static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
2073+
"Specialization of instance_traits not found. Please check that a "
2074+
"specialization exists in file "
2075+
"ck_tile/builder/reflect/"
2076+
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
2077+
"for the given template parameters.");
2078+
return ck_tile::reflect::instance_string<DeviceOp>();
2079+
}
2080+
#endif
2081+
20662082
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
20672083
{
20682084
auto arg = dynamic_cast<const Argument*>(p_arg);

0 commit comments

Comments
 (0)