3 changes: 2 additions & 1 deletion example/ck_tile/18_flatmm/CMakeLists.txt
@@ -14,6 +14,7 @@ if(has_supported_gpu)
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only

set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
@@ -27,6 +28,6 @@ if(has_supported_gpu)
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only
endif()

506 changes: 506 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
@@ -0,0 +1,15 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"

#include "mxfp4_flatmm.hpp"
40 changes: 40 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
@@ -0,0 +1,40 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;

static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;

static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;

static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
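The numbers above compose as follows: the 128x512x256 block tile is covered by a 1x4x1 warp grid of 16x16x128 warp tiles, so N_Repeat = 512 / 16 / 4 = 8 repeats per warp along N. A minimal sketch of the consistency checks this implies (illustrative only, not part of the PR):

// Illustrative consistency checks for MXfp4_FlatmmConfig16 (not in the PR).
static_assert(MXfp4_FlatmmConfig16::M_Tile %
(MXfp4_FlatmmConfig16::M_Warp * MXfp4_FlatmmConfig16::M_Warp_Tile) == 0,
"the M tile must be covered by whole warp tiles");
static_assert(MXfp4_FlatmmConfig16::K_Tile % MXfp4_FlatmmConfig16::K_Warp_Tile == 0,
"the K tile must be a multiple of the warp K tile");
static_assert(MXfp4_FlatmmConfig16::N_Repeat == 8, // 512 / 16 / 4
"N_Repeat is the per-warp repeat count along N");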
167 changes: 167 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename PrecActType,
typename PrecWeightType,
typename CDataType,
typename FlatmmConfig,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
int run_mx_flatmm_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;

using ADataType = PrecActType;
using BDataType = PrecWeightType;
using AccDataType = float;

using ScaleType = ck_tile::e8m0_t;

constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
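// Per the OCP microscaling (MX) convention, one e8m0_t scale (an 8-bit
// power-of-two exponent) covers a 1x32 run of elements along K; granularity 1
// in M and N means each row of A (column of B) carries its own scale vector.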

ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");

ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t init_method = arg_parser.get_int("init");
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));

auto scale_stride_A = ck_tile::get_default_stride(
M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
auto scale_stride_B = ck_tile::get_default_stride(
K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));

if(K % ScaleGranularityK != 0)
throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
throw std::runtime_error("wrong! K must be multiple of packed size.");
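// pk_fp4_t packs two 4-bit values per byte (PackedSize == 2), so the check
// above additionally requires K to be even for packed operands.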

ck_tile::HostTensor<ADataType> a_host(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_origin_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));

if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
}

ck_tile::HostTensor<BDataType> b_shuffled_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);

const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
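// preShuffleWeight/preShuffleScale are defined in mx_flatmm.cpp (not rendered
// in this diff); presumably they reorder B and the scale tensors into the
// tile-contiguous layout the flatmm pipeline streams from global memory.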

ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());

ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());

a_dev_buf.ToDevice(a_host.data());
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
c_rslt_host.SetZero();
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());

auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
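// FlatmmScalePointer presumably bundles the scale base pointer with its
// leading dimension (number of scale rows for A, scale columns for B); the
// granularities are template parameters so the kernel can index one scale per
// 1x32 block along K.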

invoke_mx_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(scale_a_dev_ptr),
decltype(scale_b_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffled_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
scale_a_dev_ptr,
scale_b_dev_ptr,
n_warmup,
n_repeat);

c_dev_buf.FromDevice(c_rslt_host.data());

bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();

ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);

const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;

pass = ck_tile::check_err(
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}

return pass;
}
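Since mx_flatmm.cpp is not rendered in this diff, the following is a hypothetical sketch of how the entry point might instantiate this helper; the layouts and data types are assumptions inferred from the config and reference code above, not the PR's actual code:

// Hypothetical caller sketch (mx_flatmm.cpp is not rendered above).
int run_mx_flatmm(int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t, // activations
ck_tile::pk_fp4_t, // weights
ck_tile::half_t, // output
MXfp4_FlatmmConfig16>(argc, argv, Row{}, Col{}, Row{});
}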
87 changes: 87 additions & 0 deletions include/ck_tile/host/reference/reference_gemm.hpp
@@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const HostTensor<ScaleDataType>& scale_a,
const HostTensor<ScaleDataType>& scale_b,
const AElementOp& = {},
const BElementOp& = {},
const ACCElementOp& = {})
{
static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);

const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);

const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
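// One scale per ScaleBlockSize contiguous elements along K (32 for MXFP4).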

HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
{std::size_t(K), std::size_t(1)});
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
{std::size_t(1), std::size_t(K)});

for(std::size_t m = 0; m < M; ++m)
{
for(std::size_t k = 0; k < K; ++k)
{
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
if(k % 2 == 1)
continue; // skip odd k

auto a_f4x2 = a_m_k(m, k);
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
auto a_f4_lo =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
auto a_f4_hi =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));

a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
}
else
{
// Mirror the B path below: without this branch, a_m_k_scaled would be
// left uninitialized for non-pk_fp4_t activation types.
a_m_k_scaled(m, k) =
ck_tile::type_convert<AccDataType>(a_m_k(m, k)) *
ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
}
}
}

for(std::size_t n = 0; n < N; n++)
{
for(std::size_t k = 0; k < K; k++)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
if(k % 2 == 1)
continue; // skip odd k

auto b_f4x2 = b_k_n(k, n);
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
auto b_f4_lo =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
auto b_f4_hi =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));

b_k_n_scaled(k, n) = b_f4_lo * b_scale;
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
}
else
{
b_k_n_scaled(k, n) =
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
}
}
}

// call reference gemm
reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
a_m_k_scaled, b_k_n_scaled, c_m_n);
}
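The dequantization above is elementwise: each operand value is multiplied by the power-of-two scale of its 1xScaleBlockSize block along K. A standalone sketch of the e8m0 decode, assuming the usual biased-exponent encoding (value = 2^(e - 127); the NaN encoding e = 0xFF is ignored here):

#include <cmath>
#include <cstdint>

// Standalone sketch: decode an e8m0 scale, assuming a bias-127 exponent byte.
inline float e8m0_to_float(std::uint8_t e) { return std::ldexp(1.0f, int(e) - 127); }

// a_scaled[m][k] = decode_fp4(a[m][k]) * e8m0_to_float(scale_a[m][k / ScaleBlockSize])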

template <typename ADataType,
typename BDataType,
typename DsDataType,
3 changes: 3 additions & 0 deletions include/ck_tile/ops/flatmm.hpp
@@ -13,11 +13,14 @@
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
4 changes: 2 additions & 2 deletions include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
@@ -902,8 +902,8 @@ struct FlatmmKernel
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
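// amd_wave_read_first_lane (used above) presumably wraps
// __builtin_amdgcn_readfirstlane, broadcasting lane 0's value so the tile
// offsets are wave-uniform and can be kept in scalar registers.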

const SplitKBatchOffset splitk_batch_offset(kargs);
// options