Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions Test/WMMA/wmma_test_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,19 @@

#if defined(__gfx12__)
#define WMMA_DATA_WIDTH 8
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 8 ) ) );
typedef float frag_type_c __attribute__( ( ext_vector_type( 8 ) ) );
typedef __fp16 half_2 __attribute__( ( ext_vector_type( 2 ) ) );
#else
#define WMMA_DATA_WIDTH 16
typedef _Float16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
typedef __fp16 frag_type __attribute__( ( ext_vector_type( 16 ) ) );
typedef __fp16 frag_type_c __attribute__( ( ext_vector_type( 16 ) ) );
#endif

extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
__device__ half_2 packFp32s( float a, float b ) { return __builtin_amdgcn_cvt_pkrtz( a, b ); }

extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
{
const int gIdx = blockIdx.x * blockDim.x + threadIdx.x;
const int lIdx = threadIdx.x;

// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
Expand All @@ -58,54 +62,60 @@ extern "C" __global__ void wmma_matmul( __half* a, __half* b, __half* c )
frag_type a_frag;
frag_type b_frag;
// initialize c fragment to 0
frag_type c_frag = {};
frag_type_c c_frag = {};

// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
const int lane = lIdx % 16;
const int laneWrapped = lIdx % 16;
const int laneGroup = lIdx / 16;
#if defined( __gfx12__ )
#if 1
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
{
b_frag[ele] = b[16 * (ele+laneGroup * WMMA_DATA_WIDTH) + lane];
}

for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
{
a_frag[ele] = a[16 * lane + ele+laneGroup * WMMA_DATA_WIDTH];
b_frag[ele] = b[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped];
a_frag[ele] = a[16 * laneWrapped + ( ele + laneGroup * WMMA_DATA_WIDTH )];
}
#else
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
{
b_frag[ele] = b[16 * ele + lane];
{//with __builtin_amdgcn_cvt_pkrtz
half_2* a_ptr = reinterpret_cast<half_2*>( &a_frag );
half_2* b_ptr = reinterpret_cast<half_2*>( &b_frag );
for( int ele = 0; ele < WMMA_DATA_WIDTH / 2; ++ele )
{
const int e0 = ele * 2 + 0;
const int e1 = ele * 2 + 1;
b_ptr[ele] = packFp32s( b[16 * ( e0 + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped], b[16 * ( e1 + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped] );
a_ptr[ele] = packFp32s( a[16 * laneWrapped + ( e0 + laneGroup * WMMA_DATA_WIDTH )], a[16 * laneWrapped + ( e1 + laneGroup * WMMA_DATA_WIDTH )] );
}
}

#endif
#else
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
{
a_frag[ele] = a[16 * lane + ele];
b_frag[ele] = b[16 * ele + laneWrapped];
a_frag[ele] = a[16 * laneWrapped + ele];
}
#endif
// call the WMMA compiler intrinsic
// more details available in the RDNA3 ISA guide - https://developer.amd.com/wp-content/resources/RDNA3_Shader_ISA_December2022.pdf
// more details available in the RDNA4 ISA guide - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna4-instruction-set-architecture.pdf
// the last parameter is called "OPSEL" which decides which half of the VGPRs of c_frag the results are stored into
// this will only compile on RDNA3
#if defined( __gfx12__ )
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( a_frag, b_frag, c_frag );
#else
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( a_frag, b_frag, c_frag, false );
#endif
#if defined( __gfx12__ )
for( int ele = 0; ele < WMMA_DATA_WIDTH; ++ele )
{
c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane] = c_frag[ele];
c[16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped] = c_frag[ele];
}
#else
for( int ele = 0; ele < 8; ++ele )
{
const int r = ele * 2 + ( lIdx / 16 );
// store results from unpacked c_frag output
c[16 * r + lane] = c_frag[ele * 2];
c[16 * r + laneWrapped] = c_frag[ele * 2];
// if OPSEL was set to "true", the line above would instead be
// c[16 * r + lane] = c_frag[ele*2 + 1];
// c[16 * r + laneWrapped] = c_frag[ele*2 + 1];
}
#endif
}