Skip to content

Commit 7bee717

Browse files
committed
fix rocm build error
1 parent 890b0f1 commit 7bee717

File tree

5 files changed

+9
-4
lines changed

5 files changed

+9
-4
lines changed

csrc/custom_marlin/gptq_marlin/gptq_marlin.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ template <typename T> inline std::string str(T x) { return std::to_string(x); }
3434

3535
namespace gptq_marlin {
3636

37-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
37+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
3838

3939
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
4040
int const* __restrict__ perm_int_ptr,

csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ using I4 = Vec<int, 4>;
3838

3939
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
4040

41-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
41+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
4242
// No support for async
4343
#else
4444

csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
#include <cuda_bf16.h>
99
#include <cuda_fp16.h>
1010

11+
#ifdef __HIP_PLATFORM_AMD__
12+
typedef __hip_bfloat16 nv_bfloat16;
13+
typedef __hip_bfloat162 nv_bfloat162;
14+
#endif
15+
1116
namespace gptq_marlin {
1217

1318
template <typename scalar_t> class ScalarType {};

csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ static constexpr int repack_threads = 256;
99
static constexpr int tile_k_size = tile_size;
1010
static constexpr int tile_n_size = tile_k_size * 4;
1111

12-
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
12+
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)
1313

1414
template <int const num_threads, int const num_bits, bool const has_perm>
1515
__global__ void marlin_repack_kernel(

ktransformers/operators/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from torch import Tensor, nn
1717
if not torch.xpu.is_available():
1818
import KTransformersOps
19-
import vLLMMarlin
2019
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
2120
from ktransformers.util.utils import InferenceState
2221
if not torch.xpu.is_available():
@@ -520,6 +519,7 @@ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Ten
520519
# padding x.shape[0] to avoid CUDA illegal memory access error
521520
x, orig_size_m = self._pad_input(x)
522521

522+
import vLLMMarlin
523523
x = vLLMMarlin.gptq_marlin_gemm(
524524
x,
525525
self.marlin_q_w,

0 commit comments

Comments (0)