diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index de8ba4f648..31be8c322c 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -158,6 +158,28 @@ struct ConvTensorLayouts +consteval auto GetTensorLayout() +{ + + if constexpr(SPATIAL_DIM == 1) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 2) + { + return factory_internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 3) + { + return factory_internal::ConvTensorLayouts{}; + } + else + { + static_assert(false, "Unsupported spatial dimension for convolution layout."); + } +} + // Type mappings from builder convolution data type to CK tensor types. template struct ConvTensorTypes @@ -432,16 +454,19 @@ template struct ConvFactory; -// Factory specialization for an instance of a grouped forward convolution kernel. +// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// of a grouped forward convolution kernel. template - requires ConvDirectionIsForward + requires ConvDirectionIsForward && + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 struct ConvFactory { static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = - factory_internal::ConvTensorLayouts; + using Layouts = decltype(factory_internal::GetTensorLayout()); using Types = factory_internal::ConvTensorTypes; using Ops = factory_internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 0851f0061e..370e7b6521 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -21,6 +21,7 @@ #include #include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/conv_signature_predicates.hpp" namespace ck_tile::builder { @@ -40,16 +41,21 @@ template concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); +template +concept ConvDeviceOp = std::same_as, GroupConvDeviceOp>; + +template +concept ConvLayout = std::same_as, GroupConvLayout>; + // Concept for a type that defines a convolution's operational signature. template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; { t.direction } -> std::convertible_to; - requires std::convertible_to || - std::convertible_to || - std::convertible_to; + { t.layout } -> ConvLayout; { t.data_type } -> std::convertible_to; { t.elementwise_operation } -> std::convertible_to; + { t.device_operation } -> ConvDeviceOp; }; // Concept to validate a convolution signature's values. @@ -57,18 +63,7 @@ template concept ValidConvSignature = requires { requires ConvSpatialDim; requires ConvDataType; + requires IsValidConvDeviceOp; }; -// Predicate for forward convolution. -template -concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); - -// Predicate for backward data convolution. -template -concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); - -// Predicate for backward weight convolution. -template -concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp new file mode 100644 index 0000000000..f947c7e329 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/conv_signature_predicates.hpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder { + +/********************************************** + * Conv Direction Predicates + **********************************************/ + +// Predicate for forward convolution. +template +concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD); + +// Predicate for backward data convolution. +template +concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA); + +// Predicate for backward weight convolution. +template +concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT); + +/********************************************** + * Conv Fwd Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3); + +// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK); + +// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = + (Sig.device_operation._fwd == + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor); + +// Generic predicate to check if signature uses any forward convolution device operation. +template +concept ConvDeviceOpIsForward = + ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 || + ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + +/********************************************** + * Conv Bwd Weight Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdWeight operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle); + +// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3); + +// Predicate for DeviceGroupedConvBwdWeightMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD); + +// Predicate for DeviceGroupedConvBwdWeight_Dl operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl = + (Sig.device_operation._bwd_weight == + BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl); + +// Generic predicate to check if signature uses any backward weight convolution device operation. +template +concept ConvDeviceOpIsBackwardWeight = + ConvDeviceOpIs_DeviceGroupedConvBwdWeight || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 || + ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl; + +/********************************************** + * Conv Bwd Data Device Op Predicates + **********************************************/ + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1); + +// Predicate for DeviceGroupedConvBwdDataMultipleD operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD); + +// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation. +template +concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle = + (Sig.device_operation._bwd_data == + BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle); + +// Generic predicate to check if signature uses any backward data convolution device operation. +template +concept ConvDeviceOpIsBackwardData = + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD || + ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + +/********************************************** + * Generic Device Op Predicates + **********************************************/ + +// Generic predicate to check if signature uses any device operation. +template +concept IsValidConvDeviceOp = ConvDeviceOpIsForward || ConvDeviceOpIsBackwardData || + ConvDeviceOpIsBackwardWeight; + +} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 21201b8d50..9ab827e3a5 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -69,7 +69,8 @@ template + typename BComputeDataType, + bool DirectLoad> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; } // namespace ck::tensor_operation::device @@ -124,7 +125,8 @@ template + typename BComputeDataType_, + bool DirectLoad> struct InstanceTraits> + BComputeDataType_, + DirectLoad>> { // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; @@ -336,6 +339,7 @@ struct InstanceTraits(); // 47. AComputeDataType oss << "," << detail::type_name(); // 48. BComputeDataType + oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad oss << ">"; return oss.str(); diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7f49e77f81..47bd8327d4 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -48,6 +48,20 @@ enum class GroupConvLayout3D NGCDHW_GKCZYX_NGKDHW, }; +struct GroupConvLayout +{ + union + { + GroupConvLayout1D _1d; + GroupConvLayout2D _2d; + GroupConvLayout3D _3d; + }; + + constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {} + constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {} + constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {} +}; + // Direction of the convolution operation. enum class ConvDirection { @@ -56,6 +70,52 @@ enum class ConvDirection BACKWARD_WEIGHT }; +// Forward convolution device operations. +enum class FwdGroupConvDeviceOperation +{ + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +}; + +// Backward data convolution device operations. +enum class BwdDataGroupConvDeviceOperation +{ + DeviceGroupedConvBwdDataMultipleD, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +}; + +// Backward weight convolution device operations. +enum class BwdWeightGroupConvDeviceOperation +{ + DeviceGroupedConvBwdWeight, + DeviceGroupedConvBwdWeight_Dl, + DeviceGroupedConvBwdWeight_Xdl_CShuffle, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3, + DeviceGroupedConvBwdWeight_Wmma_CShuffle, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle, + DeviceGroupedConvBwdWeightMultipleD, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle, +}; + +// Structural type for device operation +struct GroupConvDeviceOp +{ + union + { + FwdGroupConvDeviceOperation _fwd; + BwdDataGroupConvDeviceOperation _bwd_data; + BwdWeightGroupConvDeviceOperation _bwd_weight; + }; + + constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {} + constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {} + constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {} +}; + // Fused element-wise operations. enum class ElementwiseOperation { diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b7adbc116a..c53ce6210a 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -19,7 +19,7 @@ endfunction() # The test_conv_builder target has all the unit tests (each test should run < 10 ms) add_ck_builder_test(test_conv_builder test_conv_builder.cpp - test_instance_traits.cpp + test_fwd_instance_traits.cpp test_instance_traits_util.cpp) add_ck_builder_test(test_inline_diff test_inline_diff.cpp) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index d5b8802896..77ff0fe28f 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -9,12 +9,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 1, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout1D::NGCW_GKXC_NGKW, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::SCALE}; + .elementwise_operation = ElementwiseOperation::SCALE, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index 77c5c80489..5be7d5e604 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; @@ -28,12 +30,14 @@ TEST(FwdConvInstances, TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index c81d7543bb..4abe3df40d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index d55a120bb8..5ea804cf8b 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -7,12 +7,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 2, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index f7bcf49e54..c729148346 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, .data_type = DataType::BF16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index 27b5ddc821..832acd7412 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, .data_type = DataType::FP16, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index c0b6f04383..9d0e107dbc 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -8,12 +8,14 @@ namespace ck_tile::builder::testing { TEST(FwdConvInstances, Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst) { - constexpr ConvSignature FwdConvSignature{ + constexpr ConvSignature FwdConvSignature{ .spatial_dim = 3, .direction = ConvDirection::FORWARD, .layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, .data_type = DataType::FP32, - .elementwise_operation = ElementwiseOperation::PASS_THROUGH}; + .elementwise_operation = ElementwiseOperation::PASS_THROUGH, + .device_operation = + FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3}; constexpr ThreadBlock FwdThreadBlock{.block_size = 256, .tile_size = {.m = 128, .n = 128, .k = 32}}; diff --git a/experimental/builder/test/impl/conv_signature_types.hpp b/experimental/builder/test/impl/conv_signature_types.hpp index 297f827395..cc5490c711 100644 --- a/experimental/builder/test/impl/conv_signature_types.hpp +++ b/experimental/builder/test/impl/conv_signature_types.hpp @@ -3,11 +3,13 @@ #pragma once +#include #include "ck_tile/builder/conv_signature_concepts.hpp" namespace ck_tile::builder::test { -template +using namespace ck_tile::builder; + struct ConvSignature { int spatial_dim; @@ -15,9 +17,8 @@ struct ConvSignature GroupConvLayout layout; DataType data_type; ElementwiseOperation elementwise_operation; + GroupConvDeviceOp device_operation; }; -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); -static_assert(ConvSignatureDescriptor>); +static_assert(ConvSignatureDescriptor); } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index 7ad01bd922..cd3943d26f 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -11,7 +11,7 @@ using namespace ck_tile::builder; using namespace test; // Common test implementation -template