[CK Tile] Grouped conv fwd splitn support #2776
Conversation
The Split-N implementation remains functional in:
- include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
- include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

Tests pass locally and Split-N verified working with tile_example_grouped_conv_fwd.
Awesome, glad you have tested it! LGTM
Nice
// Note: GemmM will be updated after Split-N calculation
// Initially set to full size, will be adjusted after transformer creation
GemmM = args.N_ * args.output_spatial_lengths_[0];
So what is the reason to initialize GemmM here?
You're right. This initialization is redundant.
return dim3(
    TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
// Ensure n_splits is at least 1 (defensive programming)
const index_t grid_z = (kargs.n_splits == 0 || kargs.n_splits > 1024) ? 1 : kargs.n_splits;
why > 1024?
I removed this line. It looks overly defensive.
- n_splits can't be 0: the default is 1, and ceiling division always returns ≥ 1
- 1024 is arbitrary: there is no technical reason for this limit, and it could block valid large-batch cases
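For reference, a fragment-level sketch of what the grid computation can look like once the clamp is gone. The function signature and the KernelArgs name are assumptions; only the kargs fields visible in the diff above come from the PR.

```cpp
// Sketch only, not the exact PR code. n_splits >= 1 by construction
// (integer_divide_ceil of original_n by n_per_split), so no clamp is needed.
// Split-K and Split-N both map to grid.z, so at most one of them is > 1 here.
CK_TILE_HOST static dim3 GridSize(const KernelArgs& kargs)
{
    const index_t grid_z = kargs.k_batch > 1 ? kargs.k_batch : kargs.n_splits;
    return dim3(TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, grid_z);
}
```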
// Check if this split is valid
// With exact divisors, this should never happen, but keep as safety check
if(batch_offset >= kargs.original_n)
I think this extra condition is not needed and can impact performance.
You are right! I removed it; it's an unnecessary validation check.
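For the record, the check is provably dead: with n_splits = ceil(original_n / n_per_split) we have (n_splits - 1) * n_per_split < original_n, and split_id ≤ n_splits - 1, so batch_offset = split_id * n_per_split can never reach original_n.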
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = (original_n + n_per_split - 1) / n_per_split; // Calculate number of splits
Use integer_divide_ceil function
Thank you! I updated it.
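For reference, a sketch of how the updated lines might read, assuming ck_tile::integer_divide_ceil takes the numerator first and the divisor second:

```cpp
// Same computation as above, expressed with the library helper.
n_per_split = conv_to_gemm_transformer.GetN();
original_n  = conv_to_gemm_transformer.GetOriginalN();
n_splits    = ck_tile::integer_divide_ceil(original_n, n_per_split);
```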
const index_t input_batch_offset = batch_offset * kargs.input_batch_stride;
const index_t output_batch_offset = batch_offset * kargs.output_batch_stride;
Please use __builtin_amdgcn_readfirstlane
Thank you, I updated it.
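A fragment-level sketch of one way the suggestion can be applied, folding in the later long_index_t overflow fix; scalar_batch_offset is an illustrative name, and the kargs stride fields come from the diff above.

```cpp
// __builtin_amdgcn_readfirstlane works on 32-bit values, so it is applied to the
// 32-bit batch index; the result is then widened to long_index_t before multiplying
// by the per-batch strides, since those products can exceed 2^31.
const index_t scalar_batch_offset = __builtin_amdgcn_readfirstlane(batch_offset);
const long_index_t input_batch_offset =
    static_cast<long_index_t>(scalar_batch_offset) * kargs.input_batch_stride;
const long_index_t output_batch_offset =
    static_cast<long_index_t>(scalar_batch_offset) * kargs.output_batch_stride;
```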
const index_t split_id = blockIdZ;

// Calculate batch offset for this split
const index_t batch_offset = split_id * kargs.n_per_split;
__builtin_amdgcn_readfirstlane
Thank you! I updated it.
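The matching sketch for the split index itself; blockIdx.z is already uniform across the wave, so the builtin only moves the value into a scalar register:

```cpp
// Sketch: scalarize the split index before it feeds the batch-offset math.
const index_t split_id     = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t batch_offset = split_id * kargs.n_per_split;
```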
- Remove redundant GemmM initialization
- Add __builtin_amdgcn_readfirstlane for batch offsets
- Remove unnecessary validation checks
- Use ck_tile::integer_divide_ceil instead of manual ceiling division
k_batch = args.k_batch;

GemmM = args.N_ * args.output_spatial_lengths_[0];
// GemmM will be set after Split-N calculation
Can you apply the same approach to each constructor?
- Fix 32-bit integer overflow in batch offset calculation by using long_index_t
- Remove redundant GemmM assignments in 2D/3D constructors for consistency
- This fixes crashes when using 6+ splits with large tensors (>2GB)
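A standalone illustration of the overflow being fixed here; the stride and offset values are hypothetical, not taken from the PR.

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

// Illustrative numbers only: with 6+ splits of a large tensor, batch_offset *
// batch_stride can exceed INT32_MAX, so the product must be formed in 64-bit
// (long_index_t) rather than 32-bit (index_t) arithmetic.
int main()
{
    const std::int32_t batch_offset = 2500;      // e.g. split index 5 with n_per_split = 500
    const std::int32_t batch_stride = 1'000'000; // hypothetical per-batch stride

    const std::int64_t offset64 =
        static_cast<std::int64_t>(batch_offset) * batch_stride; // 2'500'000'000

    std::printf("64-bit offset: %lld (INT32_MAX is %d) -> %s\n",
                static_cast<long long>(offset64),
                std::numeric_limits<std::int32_t>::max(),
                offset64 > std::numeric_limits<std::int32_t>::max()
                    ? "would overflow a 32-bit index"
                    : "fits in 32 bits");
    return 0;
}
```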
LGTM
## What's New
Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension.

## Bug Fix
Fixed 32-bit integer overflow that caused crashes with 6+ splits:
- Use `long_index_t` for batch offset calculations
- Remove redundant GemmM initialization in constructors

## How It Works
- Automatically splits batch dimension when tensor exceeds 2GB
- Uses grid.z dimension for parallel processing of splits
- Each split processes a subset of batches independently

## Testing
Verified with tile_example_grouped_conv_fwd:
- n=3000 (6 splits) ✓
- n=3500 (7 splits) ✓
- n=10480 (40 splits) ✓
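As a rough, standalone sketch of the splitting policy just described (this is not the actual GetSplitedNSize() code; the function name, divisor-search strategy, and parameters are assumptions):

```cpp
#include <cstdint>

// Pick the largest n_per_split that divides N exactly and keeps each split under
// the 2 GB addressing limit; if no divisor larger than 1 fits, fall back to
// processing one batch element per split (the behavior discussed later for
// odd/prime batch sizes).
std::int64_t choose_n_per_split(std::int64_t N, std::int64_t bytes_per_batch)
{
    const std::int64_t limit = std::int64_t{2} * 1024 * 1024 * 1024; // 2 GB
    for(std::int64_t candidate = N; candidate > 1; --candidate)
    {
        if(N % candidate == 0 && candidate * bytes_per_batch <= limit)
            return candidate;
    }
    return 1;
}
```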
Proposed changes
This PR adds Split-N functionality to the grouped convolution forward operation in CK Tile, enabling efficient processing of large tensors (>2GB) that would otherwise face memory addressing limitations.
Problem solved:
Implementation Details
Core Split-N logic:
- grouped_convolution_forward_kernel.hpp: Added Split-N support with grid.z dimension for batch parallelization
- transform_conv_fwd_to_gemm.hpp: Added GetSplitedNSize() to detect 2GB threshold and calculate optimal number of splits

Testing: tile_example_grouped_conv_fwd
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files

Test Results (MI250)
tile_example_grouped_conv_fwd:
- 100 MB threshold with CPU reference:
- 2GB threshold without CPU reference:
Discussion
Important Split-N behavior for odd/prime batch sizes
Currently, when N cannot be evenly divided, the implementation falls back to n_per_split=1, meaning each batch element is processed separately:
Example:
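As an illustration (hypothetical numbers, not necessarily the figures from the original example): with N = 3001, which is prime, the only divisors are 1 and 3001 itself, and the unsplit tensor is what already exceeds the limit, so the kernel falls back to 3001 grid.z blocks of one batch each, whereas the even N = 3000 case from the test results needs only 6 splits of 500.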
This fallback ensures correctness but can lead to inefficient parallelization for prime batch sizes. Potential improvements for future work:
Split-K conflict prevention
Both Split-K and Split-N use blockIdx.z for parallelization. The current implementation prevents using both simultaneously via a check on the kernel arguments. This ensures correctness but may limit optimization opportunities in some scenarios.
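A fragment-level sketch of the kind of guard described here (not the exact PR code; the function and argument-struct names are assumptions):

```cpp
// Split-K and Split-N would both need blockIdx.z, so reject configurations
// that try to enable both at once.
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
    return !(kargs.k_batch > 1 && kargs.n_splits > 1);
}
```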