Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs

using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<GroupedConvTraitsType_::NDimSpatial,
GroupedConvTraitsType_::ConvSpecialization>;
GroupedConvTraitsType_::ConvSpecialization,
true>; // Split N enabled
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

template <
Expand Down Expand Up @@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs

k_batch = args.k_batch;

GemmM = args.N_ * args.output_spatial_lengths_[0];
// GemmM will be set after Split-N calculation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you apply same approach for each constructor?

GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0];
GemmBatch = args.G_;
Expand Down Expand Up @@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;

// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);

// Calculate batch strides for NWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];

// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
}

template <
Expand Down Expand Up @@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs

k_batch = args.k_batch;

GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
GemmBatch = args.G_;
Expand Down Expand Up @@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;

// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);

// Calculate batch strides for NHWGC layout
input_batch_stride =
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
output_batch_stride =
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];

// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
}

template <
Expand Down Expand Up @@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs

k_batch = args.k_batch;

GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
// Note: GemmM will be set after Split-N calculation
GemmN = args.K_;
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
args.filter_spatial_lengths_[2];
Expand Down Expand Up @@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs
1,
std::multiplies<index_t>());
group_stride_c = args.K_;

// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);

// Calculate batch strides for NDHWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];

// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
args.output_spatial_lengths_[2];
}

using AGridDescMK = remove_cvref_t<
Expand Down Expand Up @@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;

// Split-N support fields - initialize to safe defaults
index_t n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2)
index_t n_per_split = 1; // Batches per split (N_ from transformer)
index_t original_n = 1; // Original batch size before splitting
index_t input_batch_stride = 0; // Stride to next batch in input tensor
index_t output_batch_stride = 0; // Stride to next batch in output tensor
};

/// @brief The Grouped Convolution Forward kernel template.
Expand Down Expand Up @@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel
// clang-format on
}

CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits);
}

CK_TILE_HOST static auto BlockSize()
Expand Down Expand Up @@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel
}
}

// Check Split-K and Split-N conflict (both use blockIdx.z)
if(kargs.k_batch > 1 && kargs.n_splits > 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
}
return false;
}

const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

Expand Down Expand Up @@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel
const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);

// options
const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
// Split-N handling: Get which split this workgroup handles
const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z);

// Calculate batch offset for this split
const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split);

// Calculate memory offsets for this split
const long_index_t input_batch_offset = static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.input_batch_stride);
const long_index_t output_batch_offset =
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);

// Adjust pointers: combine group offset and batch offset
const InDataType* a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;

// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct TransformConvFwdToGemm
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
static constexpr auto I5 = number<5>{};
#if 0 // TODO: Enable these functionalities

template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
Expand All @@ -42,24 +42,40 @@ struct TransformConvFwdToGemm

template <typename ConvDimsType>
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
const ConvDimsType& c_g_n_k_wos_lengths)
{
// Calculate strides internally assuming contiguous memory layout
ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
const index_t num_dims = a_g_n_c_wis_lengths.size();

// Calculate strides for input tensor (innermost to outermost)
a_g_n_c_wis_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 0; i--)
{
a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
}

// Calculate strides for output tensor
c_g_n_k_wos_strides[num_dims - 1] = 1;
for(index_t i = num_dims - 2; i >= 0; i--)
{
c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
}

const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
const long_index_t c_element_space_size =
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t element_space_size = ck_tile::max(
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB

const IndexType N = a_g_n_c_wis_lengths[I1];

if(element_space_size > TwoGB)
{
// Minimum divisor of N to not exceed 2GB
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);

if(divisor <= static_cast<double>(N))
{
Expand All @@ -70,7 +86,8 @@ struct TransformConvFwdToGemm
{
if(N % least_divisor == 0)
{
return N / least_divisor;
IndexType result = N / least_divisor;
return result;
}
}
// Not found, process one Convolution N per block
Expand All @@ -90,16 +107,20 @@ struct TransformConvFwdToGemm
return N;
}
}
#endif

public:
// Public getter methods for Split-N support
// Returns N_: the per-split batch size. When Split-N is active this is the
// reduced N chosen so tensors fit under the 2GB addressing limit; callers
// (the kernel-args constructors) consume it as n_per_split.
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
// Returns original_N_: the full batch size before any Split-N reduction,
// used by callers to compute n_splits = ceil(original_n / n_per_split).
CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }

CK_TILE_HOST constexpr TransformConvFwdToGemm() {}

template <typename TransformConvFwdToGemmBase>
CK_TILE_HOST
TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
: G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
original_N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.original_N_)},
Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
Expand Down Expand Up @@ -168,18 +189,14 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}

template <typename ConvDimsType,
Expand Down Expand Up @@ -223,18 +240,19 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities

// Store original N
original_N_ = c_g_n_k_wos_lengths[I1];

if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = c_g_n_k_wos_lengths[I1];
original_N_ = N_;
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}

template <typename ConvDimsType,
Expand Down Expand Up @@ -278,18 +296,18 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
#if 0 // TODO: Enable these functionalities

// Store original N before potential splitting
original_N_ = c_g_n_k_wos_lengths[I1];

if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
N_ = original_N_;
}
#endif
N_ = c_g_n_k_wos_lengths[I1];
}

#if 0 // TODO: Enable these functionalities
Expand Down Expand Up @@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm
}
}

IndexType G_, N_;
IndexType G_, N_, original_N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;
IndexType Z_, Y_, X_;
Expand Down
Loading