
Commit 731c671

[CUTLASS] Add blockwise scale gemm/bmm kernels
This PR introduces blockwise-scaled matmul and batched matmul CUTLASS kernels, adapted from SGLang (http://github.com/sgl-project/sglang), vLLM (https://github.com/vllm-project/vllm), and https://github.com/soundOfDestiny/cutlass. We add unit tests for both gemm and bmm. This PR also restores some CUTLASS gemm tests that had been removed during the Relay phase-out.
1 parent 3f985b5 commit 731c671
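
For context, the semantics these kernels implement: both operands are FP8 (e4m3), and each carries per-block float32 scale factors that are applied inside the K-loop rather than by materializing dequantized tensors. Below is a minimal reference sketch of that computation; the 1x128 activation-scale and 128x128 weight-scale granularity is an assumption borrowed from the SGLang/vLLM recipes cited above, not something this diff pins down, and the helper name is hypothetical.

#include <cstdint>
#include <vector>

// Reference (non-performant) blockwise-scaled matmul. ASSUMPTIONS for
// illustration only: scales_a holds one scale per (row, 128-wide K-block) of
// A, scales_b one scale per 128x128 block of B; the real kernel applies the
// scales in-register and never materializes dequantized operands.
void reference_blockwise_scaled_gemm(   // hypothetical helper, not in this PR
    const std::vector<float>& a,        // m x k, A already widened from e4m3
    const std::vector<float>& b,        // n x k, B^T (column-major B) widened
    const std::vector<float>& scales_a, // m x ceil(k/128)
    const std::vector<float>& scales_b, // ceil(n/128) x ceil(k/128)
    std::vector<float>& out,            // m x n
    int64_t m, int64_t n, int64_t k) {
  const int64_t kb = (k + 127) / 128;  // number of K-blocks
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        const float sa = scales_a[i * kb + p / 128];
        const float sb = scales_b[(j / 128) * kb + p / 128];
        acc += (a[i * k + p] * sa) * (b[j * k + p] * sb);
      }
      out[i * n + j] = acc;  // plain float accumulation, as in the kernel
    }
  }
}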

File tree

7 files changed (+778 / -10 lines)

3rdparty/cutlass

Submodule cutlass updated 2252 files

cmake/modules/contrib/CUTLASS.cmake

Lines changed: 5 additions & 1 deletion

@@ -58,11 +58,15 @@ if(USE_CUDA AND USE_CUTLASS)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
   endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
     target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
-    target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+    target_include_directories(tvm_cutlass_objs PRIVATE
+      ${CUTLASS_DIR}/include
+      ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
+    )
     target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
   endif()
src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu

Lines changed: 228 additions & 0 deletions (new file)

@@ -0,0 +1,228 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <fstream>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/float8.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
// clang-format on

#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    cutlass::Status error = status;                                \
    CHECK(error == cutlass::Status::kSuccess)                      \
        << "Got cutlass error: " << cutlassGetStatusString(error); \
  }

using namespace cute;
using ProblemShape = Shape<int, int, int, int>;
using tvm::runtime::NDArray;

template <typename TileShape, typename ClusterShape, typename ElementD, typename SchedulerType,
          int ScaleGranularityM = 1>
struct CutlassFP8ScaledBlockwiseGemmRunner {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  using StoreEpilogueCompute =
      typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;

  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
          ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
      ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
      EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
      ElementAccumulator, TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
                                           CollectiveMainloop, CollectiveEpilogue, SchedulerType>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr,
                const ElementBlockScale* scales_a_ptr, const ElementBlockScale* scales_b_ptr,
                ElementD* o_ptr, ProblemShape* problem_size, StrideA* stride_a, StrideB* stride_b,
                StrideD* stride_d, uint8_t* workspace, int64_t workspace_size,
                cudaStream_t stream) {
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
    static constexpr bool UsesStreamKScheduler =
        cute::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag,
                        cutlass::gemm::StreamKScheduler>;
    if constexpr (UsesStreamKScheduler) {
      using DecompositionMode = typename cutlass::gemm::kernel::detail::
          PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
      using ReductionMode = typename cutlass::gemm::kernel::detail::
          PersistentTileSchedulerSm90StreamKParams::ReductionMode;
      scheduler.decomposition_mode = DecompositionMode::StreamK;
      scheduler.reduction_mode = ReductionMode::Nondeterministic;
    }

    typename Gemm::Arguments arguments = {
        cutlass::gemm::GemmUniversalMode::kGemm,
        *problem_size,
        {a_ptr, *stride_a, b_ptr, *stride_b, scales_a_ptr, scales_b_ptr},
        {{}, nullptr, *stride_d, o_ptr, *stride_d},
        hw_info,
        scheduler};

    Gemm gemm_op;
    CUTLASS_CHECK(gemm_op.can_implement(arguments));
    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
    CUTLASS_CHECK(gemm_op.run(stream));
  }
};

template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
          typename ElementD, typename ElementBlockScale>
void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
                                       ElementBlockScale* scales_b, ElementD* out,
                                       uint8_t* workspace, int64_t workspace_size, int64_t m,
                                       int64_t n, int64_t k, cudaStream_t stream) {
  if (k > 3 * n) {
    using SchedulerType = cutlass::gemm::StreamKScheduler;
    using Runner =
        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
    using StrideA = typename Runner::StrideA;
    using StrideB = typename Runner::StrideB;
    using StrideD = typename Runner::StrideD;

    Runner runner;
    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
                    workspace, workspace_size, stream);
  } else {
    using SchedulerType = cutlass::gemm::PersistentScheduler;
    using Runner =
        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
    using StrideA = typename Runner::StrideA;
    using StrideB = typename Runner::StrideB;
    using StrideD = typename Runner::StrideD;

    Runner runner;
    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
                    workspace, workspace_size, stream);
  }
}

template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
          typename ElementD, typename ElementBlockScale>
void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
                                      ElementBlockScale* scales_b, ElementD* out,
                                      uint8_t* workspace, int64_t workspace_size, int64_t m,
                                      int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
  if (k > 3 * n) {
    using SchedulerType = cutlass::gemm::StreamKScheduler;
    using Runner =
        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
    using StrideA = typename Runner::StrideA;
    using StrideB = typename Runner::StrideB;
    using StrideD = typename Runner::StrideD;

    Runner runner;
    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
                              static_cast<int>(l)};
    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
                    workspace, workspace_size, stream);
  } else {
    using SchedulerType = cutlass::gemm::PersistentScheduler;
    using Runner =
        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
    using StrideA = typename Runner::StrideA;
    using StrideB = typename Runner::StrideB;
    using StrideD = typename Runner::StrideD;

    Runner runner;
    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
                              static_cast<int>(l)};
    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
                    workspace, workspace_size, stream);
  }
}
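
For illustration, here is a host-side sketch of how the gemm entry point above might be instantiated. The 128x128x128 tile, 1x1x1 cluster, fp16 output type, and the wrapper name are illustrative assumptions, not values fixed by this commit:

// Illustrative only: assumes the definitions from
// fp8_blockwise_scaled_gemm.cu above are in scope; tile/cluster shapes, the
// fp16 output type, and this wrapper's name are assumptions.
using ExampleTileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ExampleClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;

void launch_blockwise_gemm_example(cutlass::float_e4m3_t* a, cutlass::float_e4m3_t* b,
                                   float* scales_a, float* scales_b, cutlass::half_t* out,
                                   uint8_t* workspace, int64_t workspace_size,
                                   int64_t m, int64_t n, int64_t k, cudaStream_t stream) {
  // All pointers are device pointers; the workspace must be at least
  // gemm_op.get_workspace_size(...) bytes, which run_gemm() CHECKs.
  cutlass_fp8_blockwise_scaled_gemm<ExampleTileShape, ExampleClusterShape,
                                    cutlass::float_e4m3_t, cutlass::float_e4m3_t,
                                    cutlass::half_t, float>(
      a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, stream);
}

Whatever shapes are chosen, the entry point itself picks the scheduler: when k > 3 * n (a long reduction relative to the output width) it instantiates the stream-K runner, which splits the K dimension across SMs; otherwise it uses the persistent scheduler.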
