
Commit a267b42

Merge branch 'develop' into vpietila/ckb-consistent-naming-of-cmake-test-targets

2 parents: d460631 + 66bae43

File tree: 27 files changed (+2165 −285 lines)

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp
172 additions, 115 deletions (large diff not rendered)

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
326 additions, 0 deletions
@@ -732,4 +732,330 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
    using Base::c_thread_desc_;
};

// Naive pipeline with lowest resource request per WGP
// Implementation with direct load
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1

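// Schedule sketch of the Run() loop defined below (illustrative summary; one LDS
// buffer, one global prefetch stage):
//
//   global direct load of tile 0 -> LDS
//   direct-load barrier
//   for i = 0 .. num_loop - 2:           (hot loop, only if BlockHasHotloop(num_loop))
//       LDS -> VGPR            (tile i)
//       LDS barrier; issue direct load of tile i+1 -> LDS
//       MFMA                   (tile i)
//       direct-load barrier
//   tail: LDS -> VGPR, MFMA on the last tile
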
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                                 BlockSize,
                                                 ADataType,
                                                 BDataType,
                                                 ComputeDataType,
                                                 AccDataType,
                                                 ATileDesc,
                                                 BTileDesc,
                                                 AMmaTileDesc,
                                                 BMmaTileDesc,
                                                 ABlockTransferSrcScalarPerVector,
                                                 BBlockTransferSrcScalarPerVector,
                                                 MPerBlock,
                                                 NPerBlock,
                                                 KPerBlock,
                                                 MPerXDL,
                                                 NPerXDL,
                                                 MRepeat,
                                                 NRepeat,
                                                 KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::KRepeat;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;

    static constexpr index_t PrefetchStages  = 1;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

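    // Illustrative dispatch example (not part of this commit; the names and the
    // K / KPerBlock split are assumed): a gridwise GEMM would typically pick the
    // Run<> instantiation from these two helpers, e.g.
    //
    //   const index_t num_loop = K / KPerBlock;
    //   if(BlockHasHotloop(num_loop))    // num_loop > 1
    //       pipeline.template Run<true, TailNumber::Full>(/* ... */, num_loop);
    //   else                             // single K tile: prefetch + tail only
    //       pipeline.template Run<false, TailNumber::Full>(/* ... */, num_loop);
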
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        block_sync_lds_direct_load();

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_buf);
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf,
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_buf);
                        });
                    });
                });

                block_sync_lds();
                a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
                b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                            vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataTypeBuf,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds_direct_load();

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_buf);
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_buf);
                    });
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
                        vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
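
The two loop-dispatch helpers in the added pipeline are trivial, but a standalone sketch makes the loop accounting concrete: the do/while main body runs num_loop - 1 iterations and the tail contributes the final one, so the full K extent is covered. Everything below is hypothetical and for illustration only (the tile sizes, the simplified TailNumber enum, and the free-function copies of BlockHasHotloop / BlockLoopTailNum are not part of this commit).

// Minimal, self-contained sketch of the hot-loop/tail accounting (C++14).
#include <cstdio>

using index_t = int;
enum class TailNumber { Odd, Even, Full }; // simplified stand-in for ck::TailNumber

// Mirrors the helpers of BlockwiseGemmXdlopsDirectLoad_pipeline_v1 (PrefetchStages = 1).
constexpr index_t PrefetchStages = 1;

constexpr bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
constexpr TailNumber BlockLoopTailNum(index_t /*num_loop*/) { return TailNumber::Full; }

int main()
{
    // Hypothetical sizes: K = 4096, KPerBlock = 64 -> 64 K iterations per workgroup.
    constexpr index_t K         = 4096;
    constexpr index_t KPerBlock = 64;
    constexpr index_t num_loop  = K / KPerBlock;

    static_assert(BlockHasHotloop(num_loop), "more than one K tile -> main do/while loop runs");
    static_assert(BlockLoopTailNum(num_loop) == TailNumber::Full,
                  "this pipeline always reports a full tail");
    static_assert(!BlockHasHotloop(1), "a single K tile skips the main loop entirely");

    // Main loop covers tiles 0 .. num_loop-2; the tail handles the last tile.
    std::printf("num_loop = %d, hot-loop iterations = %d, tail iterations = 1\n",
                num_loop, num_loop - 1);
    return 0;
}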
