3 changes: 2 additions & 1 deletion example/ck_tile/18_flatmm/CMakeLists.txt
@@ -14,6 +14,7 @@ if(has_supported_gpu)
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: gfx950 only

set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
@@ -27,6 +28,6 @@ if(has_supported_gpu)
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: gfx950 only
endif()

506 changes: 506 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp

(506 added lines; large diff not rendered.)

15 changes: 15 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
@@ -0,0 +1,15 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"

#include "mxfp4_flatmm.hpp"
40 changes: 40 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
@@ -0,0 +1,40 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
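    // 16x16x128 matches the gfx950 scaled-MFMA warp-tile shape for packed fp4,
    // which is why this example is gfx950-only (see the CMake TODO).
    // With these values, N_Repeat below evaluates to 512 / 16 / 4 = 8.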

static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;

static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;

static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
167 changes: 167 additions & 0 deletions example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
@@ -0,0 +1,167 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

template <typename PrecActType,
typename PrecWeightType,
typename CDataType,
typename FlatmmConfig,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
int run_mx_flatmm_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;

using ADataType = PrecActType;
using BDataType = PrecWeightType;
using AccDataType = float;

using ScaleType = ck_tile::e8m0_t;

constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
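    // OCP MX microscaling: one shared e8m0 (power-of-two) scale per 32
    // consecutive K elements; M and N are scaled at per-element granularity.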

ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");

ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t init_method = arg_parser.get_int("init");
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));

auto scale_stride_A = ck_tile::get_default_stride(
M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
auto scale_stride_B = ck_tile::get_default_stride(
K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));

if(K % ScaleGranularityK != 0)
throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
throw std::runtime_error("wrong! K must be multiple of packed size.");

    ck_tile::HostTensor<ADataType> a_host(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_origin_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));

if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
}

ck_tile::HostTensor<BDataType> b_shuffled_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
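    // Pre-shuffle B on the host into the flat layout the flatmm pipeline
    // expects, so the kernel can consume the weights without further
    // reordering; the scales below get a matching rearrangement.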
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);

const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);

ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());

ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());

a_dev_buf.ToDevice(a_host.data());
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
c_rslt_host.SetZero();
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());

auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};

invoke_mx_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(scale_a_dev_ptr),
decltype(scale_b_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffled_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
scale_a_dev_ptr,
scale_b_dev_ptr,
n_warmup,
n_repeat);

c_dev_buf.FromDevice(c_rslt_host.data());

bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();

ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);

const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;

pass = ck_tile::check_err(
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}

return pass;
}
87 changes: 87 additions & 0 deletions include/ck_tile/host/reference/reference_gemm.hpp
@@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const HostTensor<ScaleDataType>& scale_a,
const HostTensor<ScaleDataType>& scale_b,
const AElementOp& = {},
const BElementOp& = {},
const ACCElementOp& = {})
{
static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);

const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);

const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
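    // Each scale covers ScaleBlockSize consecutive K elements (32 for the MX
    // formats exercised by the example driver).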

HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
{std::size_t(K), std::size_t(1)});
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
{std::size_t(1), std::size_t(K)});

for(std::size_t m = 0; m < M; ++m)
{
for(std::size_t k = 0; k < K; ++k)
{
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
if(k % 2 == 1)
continue; // skip odd k

auto a_f4x2 = a_m_k(m, k);
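                // a_f4x2 holds two packed fp4 values; unpack both nibbles,
                // apply the shared block scale, and write elements k and k+1.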
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
auto a_f4_lo =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
auto a_f4_hi =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));

a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
            }
            else
            {
                // Non-packed types: scale element-wise, mirroring the B loop.
                a_m_k_scaled(m, k) =
                    ck_tile::type_convert<AccDataType>(a_m_k(m, k)) *
                    ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
            }
}
}

for(std::size_t n = 0; n < N; n++)
{
for(std::size_t k = 0; k < K; k++)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
if(k % 2 == 1)
continue; // skip odd k

auto b_f4x2 = b_k_n(k, n);
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
auto b_f4_lo =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
auto b_f4_hi =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));

b_k_n_scaled(k, n) = b_f4_lo * b_scale;
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
}
else
{
b_k_n_scaled(k, n) =
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
}
}
}

// call reference gemm
reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
a_m_k_scaled, b_k_n_scaled, c_m_n);
}

template <typename ADataType,
typename BDataType,
typename DsDataType,
3 changes: 3 additions & 0 deletions include/ck_tile/ops/flatmm.hpp
@@ -13,11 +13,14 @@
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
4 changes: 2 additions & 2 deletions include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
@@ -902,8 +902,8 @@ struct FlatmmKernel
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
-        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
-        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
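+        // amd_wave_read_first_lane is ck_tile's portable wrapper around
+        // __builtin_amdgcn_readfirstlane; the tile origin stays wave-uniform.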

const SplitKBatchOffset splitk_batch_offset(kargs);
// options
Expand Down