Commit dee185d
[CK_TILE] Stream-K GEMM Implementation (#2781)
* Change splitk_batch_offset parameter to k_size in UniversalGemmKernel::MakeGemmTensorViews function
Prior to this change, the splitk_batch_offset parameter of
MakeGemmTensorViews had type SplitKBatchOffset. But, the only member
variable of the SplitKBatchOffset class used in the MakeGemmTensorViews
function was splitted_k (an int32_t). The splitted_k value was used as
part of defining the dimensions of the tensor view. That said, for
Stream K, we do not need to use the SplitKBatchOffset class since we are
not using Split K. Thus, this commit changes the splitk_batch_offset
parameter to a int32_t called k_size. This will avoid the constraint of
requiring a caller of MakeGemmTensorViews to use the SplitKBatchOffset
class while still providing the same functionality. Calls to
UniversalGemmKernel::MakeGemmTensorViews have been updated accordingly.
* StreamK Kernel RunGemm Implementation
Stream K cannot simply use UniversalGemmKernel's RunGemm for the
following reasons:
1. The UniversalGemmKernel::RunGemm function computes num_loop based on
a static function of the TilePartitioner. That said, for Stream K,
num_loop must be computed using a member function (namely
GetCurrentIterLength from PR #2708).
2. The UniversalGemmKernel::RunGemm function requires the use of a
SplitKBatchOffset object which is not used for Stream K since we are
not using Split K.
Thus, this change adds a RunGemm function in the StreamKKernel class.
* initial implementation for operator() for StreamKKernel: adding stream-k algorithm and calls to RunGemm
* Fix indexing and offset issues for StreamK
These changes do the following:
- Ensure offsets along the M and N dimensions are multiplied by
MPerblock or NPerBlock, respectively. This ensures tile window origins
are at the correct locations.
- Fix bug in the tile partitioner's GetTileIdxWithOffset. Now, we apply
divmod to the given references to ensure correct values are available
to the caller.
- Added documentation in the Stream-K operator()
* Initial gtests for Stream-K
These changes add an initial gtest suite for the CK Tile Stream-K
kernel. Currently, due to bugs in the StreamKTilePartitioner (which will
be handled in a future PR), there are validation issues for certain
cases which may differ on different architectures. Thus, we opted to run
cases that are only fully data-parallel (skipping others). A guard was
added to Stream-K's IsSupportedArgument method to ensure that callers
are aware of this constraint. Additionally, to ensure testing
reproducibility, options for setting the number of CUs and occupancy
were added to MakeKernelArgs.
* Use GemmPipeline operator() variant that takes hot loop and tail num
In Stream-K, the num_loop value varies per WG and per iteration of a
Stream-K loop. So instead, we use the version of the GemmPipeline's
operator() function that takes in has_hot_loop and tail_num. This is
similar to what is done in Grouped GEMM.
* changes from review: comments, move readfirstlane, remove ifndef
* Switch direction of C tensor traversal & add padding guard
Prior to this change, WGs travelled backwards through their assigned
macro tiles in the C tensor. For instance, if WG0 is responsible for C
tiles 0 and 1, it would first visit tile 1 then tile 0. This means that
the iter_end decrements in each iteration of the stream-K while loop.
Since we are working with unsigned integers, the subtraction operation
may not be safe. Thus, this change makes is such that WGs travel forward
so that their iter_start is incremented and their iter_end remains
fixed.
Additionally, we added a guard against WGs that are neither sk_blocks
nor dp_blocks to ensure such WGs do not participate in the GEMM.
Together, these changes make is such that the algorithm is correct when
sk_blocks is greater than zero.
* Disable StreamK_M256_N256_K256_SKBlocks12 test case
This instance involves >=3 WGs contributing to each macro tile in C. Due
to the use of atomics, this is resulting in precision errors. These
errors will not persist once the reduction strategy is implemented. We
will re-enable this test then.
---------
Co-authored-by: Astha Rai <[email protected]>1 parent b7a806f commit dee185d
File tree
10 files changed
+612
-35
lines changed- include/ck_tile/ops/gemm/kernel
- test/ck_tile
- gemm_streamk
10 files changed
+612
-35
lines changedLines changed: 7 additions & 12 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
646 | 646 | | |
647 | 647 | | |
648 | 648 | | |
649 | | - | |
650 | | - | |
| 649 | + | |
651 | 650 | | |
652 | | - | |
653 | | - | |
654 | | - | |
655 | | - | |
656 | | - | |
657 | | - | |
658 | | - | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
| 654 | + | |
| 655 | + | |
659 | 656 | | |
660 | 657 | | |
661 | 658 | | |
| |||
672 | 669 | | |
673 | 670 | | |
674 | 671 | | |
675 | | - | |
676 | | - | |
677 | | - | |
| 672 | + | |
678 | 673 | | |
679 | 674 | | |
680 | 675 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
374 | 374 | | |
375 | 375 | | |
376 | 376 | | |
377 | | - | |
| 377 | + | |
378 | 378 | | |
379 | 379 | | |
380 | 380 | | |
| |||
436 | 436 | | |
437 | 437 | | |
438 | 438 | | |
439 | | - | |
| 439 | + | |
440 | 440 | | |
441 | 441 | | |
442 | 442 | | |
| |||
Lines changed: 146 additions & 11 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
141 | 141 | | |
142 | 142 | | |
143 | 143 | | |
144 | | - | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
145 | 154 | | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | 155 | | |
150 | 156 | | |
151 | 157 | | |
| |||
166 | 172 | | |
167 | 173 | | |
168 | 174 | | |
169 | | - | |
170 | | - | |
| 175 | + | |
| 176 | + | |
171 | 177 | | |
172 | 178 | | |
173 | 179 | | |
174 | | - | |
175 | | - | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
176 | 192 | | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
177 | 240 | | |
178 | 241 | | |
179 | 242 | | |
| |||
199 | 262 | | |
200 | 263 | | |
201 | 264 | | |
202 | | - | |
203 | | - | |
204 | | - | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
205 | 340 | | |
206 | 341 | | |
207 | 342 | | |
| |||
Lines changed: 10 additions & 10 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
579 | 579 | | |
580 | 580 | | |
581 | 581 | | |
582 | | - | |
| 582 | + | |
583 | 583 | | |
584 | 584 | | |
585 | 585 | | |
| |||
591 | 591 | | |
592 | 592 | | |
593 | 593 | | |
594 | | - | |
| 594 | + | |
595 | 595 | | |
596 | 596 | | |
597 | 597 | | |
| |||
600 | 600 | | |
601 | 601 | | |
602 | 602 | | |
603 | | - | |
| 603 | + | |
604 | 604 | | |
605 | 605 | | |
606 | 606 | | |
| |||
617 | 617 | | |
618 | 618 | | |
619 | 619 | | |
620 | | - | |
| 620 | + | |
621 | 621 | | |
622 | 622 | | |
623 | 623 | | |
| |||
638 | 638 | | |
639 | 639 | | |
640 | 640 | | |
641 | | - | |
| 641 | + | |
642 | 642 | | |
643 | 643 | | |
644 | 644 | | |
| |||
649 | 649 | | |
650 | 650 | | |
651 | 651 | | |
652 | | - | |
| 652 | + | |
653 | 653 | | |
654 | 654 | | |
655 | 655 | | |
| |||
672 | 672 | | |
673 | 673 | | |
674 | 674 | | |
675 | | - | |
| 675 | + | |
676 | 676 | | |
677 | 677 | | |
678 | 678 | | |
| |||
687 | 687 | | |
688 | 688 | | |
689 | 689 | | |
690 | | - | |
| 690 | + | |
691 | 691 | | |
692 | 692 | | |
693 | 693 | | |
| |||
962 | 962 | | |
963 | 963 | | |
964 | 964 | | |
965 | | - | |
| 965 | + | |
966 | 966 | | |
967 | 967 | | |
968 | 968 | | |
| |||
1018 | 1018 | | |
1019 | 1019 | | |
1020 | 1020 | | |
1021 | | - | |
| 1021 | + | |
1022 | 1022 | | |
1023 | 1023 | | |
1024 | 1024 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
0 commit comments