@@ -732,4 +732,330 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
732732 using Base::c_thread_desc_;
733733};
734734
735+ // Naive pipeline with lowest resource request per WGP
736+ // Implementation with direct load
737+ // GlobalPrefetchStages: 1
738+ // LocalPreFillStages: 1
739+ // LocalPreFetchStages: 0
740+ // LocalSharedMemoryBuffer: 1
741+
742+ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
743+ index_t BlockSize,
744+ typename ADataType,
745+ typename BDataType,
746+ typename ComputeDataType,
747+ typename AccDataType,
748+ typename ATileDesc,
749+ typename BTileDesc,
750+ typename AMmaTileDesc,
751+ typename BMmaTileDesc,
752+ index_t ABlockTransferSrcScalarPerVector,
753+ index_t BBlockTransferSrcScalarPerVector,
754+ index_t MPerBlock,
755+ index_t NPerBlock,
756+ index_t KPerBlock,
757+ index_t MPerXDL,
758+ index_t NPerXDL,
759+ index_t MRepeat,
760+ index_t NRepeat,
761+ index_t KPacks>
762+ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
763+ {
764+ };
765+
766+ template <index_t BlockSize,
767+ typename ADataType,
768+ typename BDataType,
769+ typename ComputeDataType,
770+ typename AccDataType,
771+ typename ATileDesc,
772+ typename BTileDesc,
773+ typename AMmaTileDesc,
774+ typename BMmaTileDesc,
775+ index_t ABlockTransferSrcScalarPerVector,
776+ index_t BBlockTransferSrcScalarPerVector,
777+ index_t MPerBlock,
778+ index_t NPerBlock,
779+ index_t KPerBlock,
780+ index_t MPerXDL,
781+ index_t NPerXDL,
782+ index_t MRepeat,
783+ index_t NRepeat,
784+ index_t KPack
785+ // ,bool TransposeC //disable transposec right now...
786+ >
787+ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 <BlockGemmPipelineScheduler::Intrawave,
788+ BlockSize,
789+ ADataType,
790+ BDataType,
791+ ComputeDataType,
792+ AccDataType,
793+ ATileDesc,
794+ BTileDesc,
795+ AMmaTileDesc,
796+ BMmaTileDesc,
797+ ABlockTransferSrcScalarPerVector,
798+ BBlockTransferSrcScalarPerVector,
799+ MPerBlock,
800+ NPerBlock,
801+ KPerBlock,
802+ MPerXDL,
803+ NPerXDL,
804+ MRepeat,
805+ NRepeat,
806+ KPack>
807+ : BlockwiseGemmXdlops_pipeline_base<BlockSize,
808+ ADataType,
809+ BDataType,
810+ ComputeDataType,
811+ AccDataType,
812+ ATileDesc,
813+ BTileDesc,
814+ AMmaTileDesc,
815+ BMmaTileDesc,
816+ ABlockTransferSrcScalarPerVector,
817+ BBlockTransferSrcScalarPerVector,
818+ MPerBlock,
819+ NPerBlock,
820+ KPerBlock,
821+ MPerXDL,
822+ NPerXDL,
823+ MRepeat,
824+ NRepeat,
825+ KPack>
826+
827+ {
828+ using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
829+ ADataType,
830+ BDataType,
831+ ComputeDataType,
832+ AccDataType,
833+ ATileDesc,
834+ BTileDesc,
835+ AMmaTileDesc,
836+ BMmaTileDesc,
837+ ABlockTransferSrcScalarPerVector,
838+ BBlockTransferSrcScalarPerVector,
839+ MPerBlock,
840+ NPerBlock,
841+ KPerBlock,
842+ MPerXDL,
843+ NPerXDL,
844+ MRepeat,
845+ NRepeat,
846+ KPack>;
847+ using Base::I0;
848+ using Base::KRepeat;
849+ using Base::xdlops_gemm;
850+
851+ using Base::CalculateCThreadOriginDataIndex;
852+ using Base::CalculateCThreadOriginDataIndex8D;
853+ using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
854+ using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
855+ using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
856+ using Base::GetCThreadBuffer;
857+ using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
858+ using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
859+ using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
860+ using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
861+ using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
862+
863+ using Base::a_block_desc_m0_m1_m2_k;
864+ using Base::b_block_desc_n0_n1_n2_k;
865+
866+ using Base::AMmaKStride;
867+ using Base::BMmaKStride;
868+
869+ using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
870+
871+ static constexpr index_t PrefetchStages = 1 ;
872+ static constexpr index_t PrefillStages = 1 ;
873+ static constexpr index_t GlobalBufferNum = 1 ;
874+
875+ __host__ __device__ static constexpr bool BlockHasHotloop (index_t num_loop)
876+ {
877+ return num_loop > PrefetchStages;
878+ }
879+
880+ __host__ __device__ static constexpr TailNumber BlockLoopTailNum (index_t num_loop)
881+ {
882+ ignore = num_loop;
883+ return TailNumber::Full;
884+ }
885+
886+ template <bool HasMainLoop,
887+ TailNumber TailNum,
888+ typename AGridDesc,
889+ typename ABlockDesc,
890+ typename ABlockTransfer,
891+ typename AGridBuffer,
892+ typename ABlockBuffer,
893+ typename ABlockTransferStep,
894+ typename BGridDesc,
895+ typename BBlockDesc,
896+ typename BBlockTransfer,
897+ typename BGridBuffer,
898+ typename BBlockBuffer,
899+ typename BBlockTransferStep,
900+ typename CThreadBuffer>
901+ __device__ void Run (const AGridDesc& a_grid_desc,
902+ const ABlockDesc& a_block_desc,
903+ ABlockTransfer& a_blockwise_copy,
904+ const AGridBuffer& a_grid_buf,
905+ ABlockBuffer& a_block_buf,
906+ const ABlockTransferStep& a_block_copy_step,
907+ const BGridDesc& b_grid_desc,
908+ const BBlockDesc& b_block_desc,
909+ BBlockTransfer& b_blockwise_copy,
910+ const BGridBuffer& b_grid_buf,
911+ BBlockBuffer& b_block_buf,
912+ const BBlockTransferStep& b_block_copy_step,
913+ CThreadBuffer& c_thread_buf,
914+ index_t num_loop) const
915+ {
916+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
917+ a_thread_desc_.GetElementSpaceSize ());
918+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
919+ b_thread_desc_.GetElementSpaceSize ());
920+
921+ // Global prefetch 1
922+ a_blockwise_copy.Run (a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
923+ b_blockwise_copy.Run (b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
924+
925+ a_blockwise_copy.MoveSrcSliceWindow (a_grid_desc, a_block_copy_step);
926+ b_blockwise_copy.MoveSrcSliceWindow (b_grid_desc, b_block_copy_step);
927+
928+ block_sync_lds_direct_load ();
929+
930+ // Initialize C
931+ c_thread_buf.Clear ();
932+
933+ // main body
934+ if constexpr (HasMainLoop)
935+ {
936+ index_t i = 0 ;
937+ do
938+ {
939+ static_for<0 , KRepeat, 1 >{}([&](auto k) {
940+ static_for<0 , MRepeat, 1 >{}([&](auto m0) {
941+ a_thread_copy_.Run (a_block_desc_m0_m1_m2_k,
942+ make_tuple (m0, I0, I0, Number<k * AMmaKStride>{}),
943+ a_block_buf,
944+ a_thread_desc_,
945+ make_tuple (m0, I0, k, I0),
946+ a_thread_buf);
947+ static_for<0 , NRepeat, 1 >{}([&](auto n0) {
948+ b_thread_copy_.Run (b_block_desc_n0_n1_n2_k,
949+ make_tuple (n0, I0, I0, Number<k * BMmaKStride>{}),
950+ b_block_buf,
951+ b_thread_desc_,
952+ make_tuple (n0, I0, k, I0),
953+ b_thread_buf);
954+ });
955+ });
956+ });
957+
958+ block_sync_lds ();
959+ a_blockwise_copy.Run (a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
960+ b_blockwise_copy.Run (b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
961+
962+ a_blockwise_copy.MoveSrcSliceWindow (a_grid_desc, a_block_copy_step);
963+ b_blockwise_copy.MoveSrcSliceWindow (b_grid_desc, b_block_copy_step);
964+
965+ static_for<0 , KRepeat, 1 >{}([&](auto k0) {
966+ static_for<0 , MRepeat, 1 >{}([&](auto m0) {
967+ static_for<0 , NRepeat, 1 >{}([&](auto n0) {
968+ vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
969+ vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
970+
971+ static_for<0 , KPack, 1 >{}([&](auto ik) {
972+ a_thread_vec.template AsType <ComputeDataTypeBuf>()(ik) =
973+ a_thread_buf[Number<a_thread_desc_.CalculateOffset (
974+ make_tuple (m0, I0, k0, ik))>{}];
975+ b_thread_vec.template AsType <ComputeDataTypeBuf>()(ik) =
976+ b_thread_buf[Number<b_thread_desc_.CalculateOffset (
977+ make_tuple (n0, I0, k0, ik))>{}];
978+ });
979+
980+ using mfma_input_type =
981+ typename vector_type<ComputeDataTypeBuf,
982+ xdlops_gemm.K1PerXdlops >::type;
983+
984+ constexpr index_t c_offset =
985+ c_thread_desc_.CalculateOffset (make_tuple (m0, n0, 0 ));
986+
987+ xdlops_gemm.Run (
988+ a_thread_vec.template AsType <mfma_input_type>(),
989+ b_thread_vec.template AsType <mfma_input_type>(),
990+ c_thread_buf.GetVectorTypeReference (Number<c_offset>{}));
991+ });
992+ });
993+ });
994+
995+ block_sync_lds_direct_load ();
996+
997+ i += 1 ;
998+ } while (i < (num_loop - 1 ));
999+ }
1000+
1001+ // tail
1002+ if constexpr (TailNum == TailNumber::Full)
1003+ {
1004+ static_for<0 , KRepeat, 1 >{}([&](auto k) {
1005+ static_for<0 , MRepeat, 1 >{}([&](auto m0) {
1006+ a_thread_copy_.Run (a_block_desc_m0_m1_m2_k,
1007+ make_tuple (m0, I0, I0, Number<k * AMmaKStride>{}),
1008+ a_block_buf,
1009+ a_thread_desc_,
1010+ make_tuple (m0, I0, k, I0),
1011+ a_thread_buf);
1012+ static_for<0 , NRepeat, 1 >{}([&](auto n0) {
1013+ b_thread_copy_.Run (b_block_desc_n0_n1_n2_k,
1014+ make_tuple (n0, I0, I0, Number<k * BMmaKStride>{}),
1015+ b_block_buf,
1016+ b_thread_desc_,
1017+ make_tuple (n0, I0, k, I0),
1018+ b_thread_buf);
1019+ });
1020+ });
1021+ });
1022+
1023+ static_for<0 , KRepeat, 1 >{}([&](auto k0) {
1024+ static_for<0 , MRepeat, 1 >{}([&](auto m0) {
1025+ static_for<0 , NRepeat, 1 >{}([&](auto n0) {
1026+ vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
1027+ vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
1028+
1029+ static_for<0 , KPack, 1 >{}([&](auto ik) {
1030+ a_thread_vec.template AsType <ComputeDataTypeBuf>()(ik) =
1031+ a_thread_buf[Number<a_thread_desc_.CalculateOffset (
1032+ make_tuple (m0, I0, k0, ik))>{}];
1033+ b_thread_vec.template AsType <ComputeDataTypeBuf>()(ik) =
1034+ b_thread_buf[Number<b_thread_desc_.CalculateOffset (
1035+ make_tuple (n0, I0, k0, ik))>{}];
1036+ });
1037+
1038+ using mfma_input_type =
1039+ typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops >::type;
1040+
1041+ constexpr index_t c_offset =
1042+ c_thread_desc_.CalculateOffset (make_tuple (m0, n0, 0 ));
1043+
1044+ xdlops_gemm.Run (a_thread_vec.template AsType <mfma_input_type>(),
1045+ b_thread_vec.template AsType <mfma_input_type>(),
1046+ c_thread_buf.GetVectorTypeReference (Number<c_offset>{}));
1047+ });
1048+ });
1049+ });
1050+ }
1051+ }
1052+
1053+ protected:
1054+ using Base::a_thread_copy_;
1055+ using Base::a_thread_desc_;
1056+ using Base::b_thread_copy_;
1057+ using Base::b_thread_desc_;
1058+ using Base::c_thread_desc_;
1059+ };
1060+
7351061} // namespace ck
0 commit comments