
Conversation

@JH-Leon-KIM-AMD (Contributor) commented Sep 4, 2025

Proposed changes

This PR adds Split-N functionality to the grouped convolution forward operation in CK Tile, enabling efficient processing of large tensors (>2GB) that would otherwise face memory addressing limitations.

Problem solved:

  • GPU tensors exceeding 2GB (2^31 bytes) face memory addressing limitations
  • Split-N automatically partitions the batch dimension when tensors exceed this threshold
  • Enables handling of production workloads with large batch sizes

Implementation Details

Core Split-N logic:

  • grouped_convolution_forward_kernel.hpp: Added Split-N support with grid.z dimension for batch parallelization
  • transform_conv_fwd_to_gemm.hpp: Added GetSplitedNSize() to detect 2GB threshold and calculate optimal number of splits
  • The algorithm picks the smallest split count that divides N evenly while keeping each split under 2GB (see the sketch after this list)
  • Kernel uses blockIdx.z to index into split batches
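
A minimal sketch of that split-count selection, with stand-in types for the ck_tile integer aliases; the real implementation is GetSplitedNSize() in transform_conv_fwd_to_gemm.hpp and may differ in name, signature, and details:

#include <cstdint>

using index_t      = std::int32_t; // stand-in for ck_tile::index_t
using long_index_t = std::int64_t; // stand-in for ck_tile::long_index_t

// Sketch (hypothetical helper): pick the smallest split count that divides N
// evenly while keeping each split's tensor under the 2 GB (2^31 byte) limit.
index_t get_n_per_split(index_t N, long_index_t bytes_per_batch)
{
    const long_index_t two_gb = long_index_t{1} << 31;

    if(static_cast<long_index_t>(N) * bytes_per_batch < two_gb)
        return N; // fits as-is, no split needed

    for(index_t splits = 2; splits <= N; ++splits)
    {
        if(N % splits == 0 &&
           static_cast<long_index_t>(N / splits) * bytes_per_batch < two_gb)
            return N / splits; // n_per_split for the smallest valid split count
    }
    return 1; // indivisible (e.g. prime) N: one batch element per split
}

The split count then follows as n_splits = N / n_per_split, and the kernel indexes the resulting splits through blockIdx.z.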

Testing:

  • Verified locally with tile_example_grouped_conv_fwd
  • Split-N activates correctly for tensors >2GB (grid.z > 1)
  • Tested up to 200 splits successfully with modified threshold
  • Unit tests removed from PR due to Jenkins environment differences (tests pass locally)

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Test Results (MI250)

  • Local testing with tile_example_grouped_conv_fwd:

100 MB threshold (lowered to force Split-N) with CPU reference:

# N=32, H=W=112 (~200MB) - Split-N active
./bin/tile_example_grouped_conv_fwd -prec=fp16 -g=1 -n=32 -c=256 -k=256 -y=3 -x=3 -h=112 -w=112 -v=0
grid: {6052, 1, 4}, blocks: {256, 1, 1} # Z=4 splits
44.0296 ms, 10.3739 TFlops, ✓ Validation: PASSED

# N=48 (~300MB) - Split-N with 6 splits
grid: {6052, 1, 6}, blocks: {256, 1, 1} # Z=6 splits
66.0836 ms, 10.3678 TFlops, ✓ Validation: PASSED

# N=200 (~1.2GB with 100MB threshold) - Split-N with 200 splits
grid: {1208, 1, 200}, blocks: {256, 1, 1} # Z=200 splits
869.812 ms, 10.42 TFlops, ✓ CPU Validation: PASSED

2 GB threshold (the default) without CPU reference:

# N=128, C=512, H=W=96 (~1.2GB) - No Split-N
grid: {70688, 1, 2}, blocks: {256, 1, 1} # Z=2 (just under threshold)
510.706 ms, 10.4498 TFlops, ✓ Validation: PASSED

# N=256 (~2.5GB) - Split-N active
grid: {76832, 1, 4}, blocks: {256, 1, 1} # Z=4 splits
1109.5 ms, 10.4563 TFlops, ✓ Validation: PASSED

# N=320 (~3.1GB) - Split-N with 4 splits
grid: {96040, 1, 4}, blocks: {256, 1, 1} # Z=4 splits
1386.74 ms, 10.4573 TFlops, ✓ Validation: PASSED

# N=400 (~3.9GB) - Split-N with 4 splits
grid: {115248, 1, 4}, blocks: {256, 1, 1} # Z=4 splits
1664.27 ms, 10.4562 TFlops, ✓ Validation: PASSED

Discussion

Important Split-N behavior for odd/prime batch sizes

Currently, when N cannot be evenly divided, the implementation falls back to n_per_split=1, meaning each batch element is processed separately:

Example:

  • N=128: Splits evenly into 2×64 (n_splits=2, n_per_split=64) ✓
  • N=127: Cannot split evenly, falls back to 127×1 (n_splits=127, n_per_split=1) ⚠️

This fallback ensures correctness but can lead to inefficient parallelization for prime numbers. Potential improvements for future work:

  1. Padding approach: Add dummy batches to make N divisible
  2. Uneven split handling: Allow splits of different sizes
  3. Relaxed divisibility: Find nearest good divisor and handle remainder

Split-K conflict prevention

Both Split-K and Split-N use blockIdx.z for parallelization. The current implementation prevents using both simultaneously by checking in the kernel arguments. This ensures correctness but may limit optimization opportunities in some scenarios.
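
A rough sketch of that guard, with argument names (k_batch, n_splits) borrowed from the snippets quoted later in this thread; the actual check lives in the kernel-argument validation and may differ:

// Illustrative only, not the PR's exact code: reject configurations that
// request both Split-K and Split-N, since each would need to own blockIdx.z.
if(kargs.k_batch > 1 && kargs.n_splits > 1)
{
    return false; // arguments not supported; the caller must choose one
}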

@JH-Leon-KIM-AMD force-pushed the LWPCK-3657-splitn-support branch 2 times, most recently from f189102 to e64b01c on September 5, 2025 09:14
@JH-Leon-KIM-AMD force-pushed the LWPCK-3657-splitn-support branch from e64b01c to 75589b5 on September 5, 2025 09:20
The Split-N implementation remains functional in:
- include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
- include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

Tests pass locally and Split-N is verified working with tile_example_grouped_conv_fwd.
@JH-Leon-KIM-AMD force-pushed the LWPCK-3657-splitn-support branch from 315518e to 950f77d on September 6, 2025 20:35
jakpiase previously approved these changes Sep 10, 2025

@jakpiase (Contributor) left a comment:

Awesome, glad you have tested it! LGTM

@bartekxk changed the title from "Lwpck 3657 splitn support" to "[CK Tile] Grouped conv fwd splitn support" on Sep 10, 2025
@bartekxk (Contributor) left a comment:

Nice


// Note: GemmM will be updated after Split-N calculation
// Initially set to full size, will be adjusted after transformer creation
GemmM = args.N_ * args.output_spatial_lengths_[0];

Reviewer (Contributor): So what is the reason to initialize GemmM here?

@JH-Leon-KIM-AMD (Author), Sep 11, 2025: You're right. This initialization is redundant.

return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
// Ensure n_splits is at least 1 (defensive programming)
const index_t grid_z = (kargs.n_splits == 0 || kargs.n_splits > 1024) ? 1 : kargs.n_splits;

Reviewer (Contributor): why > 1024?

@JH-Leon-KIM-AMD (Author): I removed this line. It looks overly defensive.

  • n_splits can't be 0: Default is 1, and ceiling division always returns ≥1
  • 1024 is arbitrary: No technical reason for this limit, and it could block valid large batch cases


// Check if this split is valid
// With exact divisors, this should never happen, but keep as safety check
if(batch_offset >= kargs.original_n)

Reviewer (Contributor): I think this extra condition is not needed and can impact the perf.

@JH-Leon-KIM-AMD (Author): You are right! I removed it; it's an unnecessary validation check.

// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_splits = (original_n + n_per_split - 1) / n_per_split; // Calculate number of splits

Reviewer (Contributor): Use integer_divide_ceil function

@JH-Leon-KIM-AMD (Author): Thank you! I updated it.
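
For reference, the updated line presumably reads something like the following sketch (the exact code is in the PR diff):

n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);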

Comment on lines 851 to 852
const index_t input_batch_offset = batch_offset * kargs.input_batch_stride;
const index_t output_batch_offset = batch_offset * kargs.output_batch_stride;

Reviewer (Contributor): Please use __builtin_amdgcn_readfirstlane

@JH-Leon-KIM-AMD (Author): Thank you, I updated it.

const index_t split_id = blockIdZ;

// Calculate batch offset for this split
const index_t batch_offset = split_id * kargs.n_per_split;

Reviewer (Contributor): __builtin_amdgcn_readfirstlane

@JH-Leon-KIM-AMD (Author): Thank you! I updated it.

- Remove redundant GemmM initialization
- Add __builtin_amdgcn_readfirstlane for batch offsets
- Remove unnecessary validation checks
- Use ck_tile::integer_divide_ceil instead of manual ceiling division
k_batch = args.k_batch;

GemmM = args.N_ * args.output_spatial_lengths_[0];
// GemmM will be set after Split-N calculation

Reviewer (Contributor): Can you apply the same approach for each constructor?

- Fix 32-bit integer overflow in batch offset calculation by using long_index_t
- Remove redundant GemmM assignments in 2D/3D constructors for consistency
- This fixes crashes when using 6+ splits with large tensors (>2GB)
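
Taken together, the batch-offset computation after these fixes presumably looks roughly like the sketch below (variable names follow the snippets quoted above; the exact final code is in the diff). __builtin_amdgcn_readfirstlane keeps the per-split offset in a scalar register, and widening to long_index_t before the stride multiply avoids the 32-bit overflow:

// Sketch, not the verbatim PR code.
const index_t split_id     = __builtin_amdgcn_readfirstlane(blockIdx.z);
const index_t batch_offset = __builtin_amdgcn_readfirstlane(split_id * kargs.n_per_split);

// Widen to 64-bit before multiplying by the batch strides so tensors larger
// than 2 GB do not overflow a 32-bit element offset.
const long_index_t input_batch_offset =
    static_cast<long_index_t>(batch_offset) * kargs.input_batch_stride;
const long_index_t output_batch_offset =
    static_cast<long_index_t>(batch_offset) * kargs.output_batch_stride;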

@bartekxk (Contributor) left a comment:

LGTM

@JH-Leon-KIM-AMD merged commit 804065a into develop Sep 16, 2025
37 of 49 checks passed
@JH-Leon-KIM-AMD deleted the LWPCK-3657-splitn-support branch on September 16, 2025 13:56
AviralGoelAMD pushed a commit that referenced this pull request Sep 21, 2025
## What's New
Add Split-N support for grouped convolution forward to handle tensors >2GB by splitting the batch dimension.

## Bug Fix
Fixed 32-bit integer overflow that caused crashes with 6+ splits:
- Use `long_index_t` for batch offset calculations
- Remove redundant GemmM initialization in constructors

## How It Works
- Automatically splits batch dimension when tensor exceeds 2GB
- Uses grid.z dimension for parallel processing of splits
- Each split processes a subset of batches independently

## Testing
Verified with tile_example_grouped_conv_fwd:
- n=3000 (6 splits) ✓
- n=3500 (7 splits) ✓
- n=10480 (40 splits) ✓