/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <fstream>
#include <iostream>
#include <sstream>
#include <type_traits>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/float8.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
// clang-format on

// Fails with a fatal CHECK if a CUTLASS call does not return kSuccess.
// Wrapped in do/while(0) so the macro expands to a single statement and
// composes safely with if/else at call sites.
#define CUTLASS_CHECK(status)                                      \
  do {                                                             \
    cutlass::Status error = (status);                              \
    CHECK(error == cutlass::Status::kSuccess)                      \
        << "Got cutlass error: " << cutlassGetStatusString(error); \
  } while (0)

using namespace cute;
// Runtime problem shape: (M, N, K, L), where L is the batch count.
using ProblemShape = Shape<int, int, int, int>;
using tvm::runtime::NDArray;

// Builds and launches an SM90 FP8 GEMM whose mainloop fuses blockwise
// dequantization scales for A and B; ScaleGranularityM controls how many rows
// of A share one scale within a tile.
template <typename TileShape, typename ClusterShape, typename ElementD, typename SchedulerType,
          int ScaleGranularityM = 1>
struct CutlassFP8ScaledBlockwiseGemmRunner {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // No source tensor C: the epilogue writes the accumulator straight to D.
  // AlignmentC is derived from ElementD because ElementC is void.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // Epilogue visitor tree that simply fetches the accumulator; the blockwise
  // scales are already applied in the mainloop.
  using StoreEpilogueCompute =
      typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;

  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
          ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
      ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
      EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;

  // Let the builder choose the mainloop stage count after carving out the
  // shared memory reserved for the epilogue.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
      ElementAccumulator, TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // ProblemShape
                                           CollectiveMainloop, CollectiveEpilogue, SchedulerType>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr,
                const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, ProblemShape* problem_size,
                StrideA* stride_a, StrideB* stride_b, StrideD* stride_d, uint8_t* workspace,
                int64_t workspace_size, cudaStream_t stream) {
    // Assumes execution on device 0; the SM count feeds the tile scheduler.
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
    static constexpr bool UsesStreamKScheduler =
        cute::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag,
                        cutlass::gemm::StreamKScheduler>;
    if constexpr (UsesStreamKScheduler) {
      using DecompositionMode = typename cutlass::gemm::kernel::detail::
          PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
      using ReductionMode = typename cutlass::gemm::kernel::detail::
          PersistentTileSchedulerSm90StreamKParams::ReductionMode;
      scheduler.decomposition_mode = DecompositionMode::StreamK;
      // Nondeterministic reduction order trades bitwise reproducibility for speed.
      scheduler.reduction_mode = ReductionMode::Nondeterministic;
    }

    typename Gemm::Arguments arguments = {
        cutlass::gemm::GemmUniversalMode::kGemm,
        *problem_size,
        {a_ptr, *stride_a, b_ptr, *stride_b, scales_a_ptr, scales_b_ptr},
        // No C source tensor (ElementC is void), hence the nullptr.
        {{}, nullptr, *stride_d, o_ptr, *stride_d},
        hw_info,
        scheduler};

    Gemm gemm_op;
    CUTLASS_CHECK(gemm_op.can_implement(arguments));
    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
    CUTLASS_CHECK(gemm_op.run(stream));
  }
};
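
// Illustrative sketch, not used by the dispatch logic below: one concrete
// instantiation of the runner, showing the template parameter order. The
// 128x128x128 tile, 1x1x1 cluster, and float output type are assumptions
// picked for this example, not values taken from any caller in this file.
using ExampleFp8BlockwiseRunner =
    CutlassFP8ScaledBlockwiseGemmRunner<Shape<_128, _128, _128>, Shape<_1, _1, _1>, float,
                                        cutlass::gemm::PersistentScheduler>;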

// Shared launch path for the single-batch and batched entry points: builds the
// strides and problem shape, then runs the scheduler-specialized runner. This
// factors out the setup that was otherwise repeated in every dispatch branch.
template <typename SchedulerType, typename TileShape, typename ClusterShape, typename ElementA,
          typename ElementB, typename ElementD, typename ElementBlockScale>
void launch_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
                                      ElementBlockScale* scales_b, ElementD* out,
                                      uint8_t* workspace, int64_t workspace_size, int64_t m,
                                      int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
  using Runner =
      CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;

  Runner runner;
  // A is row-major (m, k), B is column-major (n, k), D is row-major (m, n);
  // the last stride component is the batch stride.
  typename Runner::StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
  typename Runner::StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
  typename Runner::StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
  ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
                            static_cast<int>(l)};
  runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
                  workspace, workspace_size, stream);
}

template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
          typename ElementD, typename ElementBlockScale>
void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
                                       ElementBlockScale* scales_b, ElementD* out,
                                       uint8_t* workspace, int64_t workspace_size, int64_t m,
                                       int64_t n, int64_t k, cudaStream_t stream) {
  // Heuristic: when K dominates N, stream-K decomposition splits the reduction
  // across SMs for better load balance; otherwise the persistent scheduler wins.
  if (k > 3 * n) {
    launch_fp8_blockwise_scaled_gemm<cutlass::gemm::StreamKScheduler, TileShape, ClusterShape>(
        a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, /*l=*/1, stream);
  } else {
    launch_fp8_blockwise_scaled_gemm<cutlass::gemm::PersistentScheduler, TileShape, ClusterShape>(
        a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, /*l=*/1, stream);
  }
}

template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
          typename ElementD, typename ElementBlockScale>
void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
                                      ElementBlockScale* scales_b, ElementD* out,
                                      uint8_t* workspace, int64_t workspace_size, int64_t m,
                                      int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
  // Same scheduler heuristic as the single-batch path, with batch count l.
  if (k > 3 * n) {
    launch_fp8_blockwise_scaled_gemm<cutlass::gemm::StreamKScheduler, TileShape, ClusterShape>(
        a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream);
  } else {
    launch_fp8_blockwise_scaled_gemm<cutlass::gemm::PersistentScheduler, TileShape, ClusterShape>(
        a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream);
  }
}
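
// Usage sketch (hypothetical caller, not code from this file): TVM-facing
// wrappers are expected to pick TileShape/ClusterShape and invoke the entry
// point roughly like this. All names below (a_data, b_data, out_data, ...)
// are placeholders for the example.
//
//   using TileShape = Shape<_128, _128, _128>;
//   using ClusterShape = Shape<_1, _2, _1>;
//   cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape>(
//       static_cast<cutlass::float_e4m3_t*>(a_data),
//       static_cast<cutlass::float_e4m3_t*>(b_data), static_cast<float*>(scales_a_data),
//       static_cast<float*>(scales_b_data), static_cast<cutlass::half_t*>(out_data),
//       workspace, workspace_size, m, n, k, stream);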