From 73ef5b9dc8b8108bb5d3a7ca1ae1213cc02d71bb Mon Sep 17 00:00:00 2001
From: Jesse Createthis
Date: Wed, 13 Aug 2025 12:11:33 +0000
Subject: [PATCH 1/4] Add compile-time flag GGML_CUDA_ALLOW_LARGE_TENSORS to bypass INT_MAX check in ggml_cuda_cpy

---
 ggml/src/ggml-cuda/cpy.cu | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index f9bb025643ca2..df609d4117c18 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -282,8 +282,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    // No INT_MAX limit – ggml_nbytes may exceed 2GB on large contexts.
+    // The underlying cudaMemcpyAsync can handle size_t lengths.
+    #else
     GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
     GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+    #endif
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];

From d3ea7d27a3e706dcce7b39f37e59e253b18a8266 Mon Sep 17 00:00:00 2001
From: Jesse CreateThis
Date: Wed, 13 Aug 2025 10:51:56 -0400
Subject: [PATCH 2/4] R1-0528's attempt to implement this. I doubt this code works. User beware.

---
 ggml/src/ggml-cuda/cpy.cu | 120 +++++++++++++++++++++++++++++++-------
 1 file changed, 99 insertions(+), 21 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index df609d4117c18..1a06f9f77c765 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -141,69 +141,147 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int num_blocks = (chunk + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+            (cx + offset * sizeof(src_t), cdst + offset * sizeof(dst_t), chunk,
+             ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+             ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+             cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     GGML_ASSERT(ne % QK8_0 == 0);
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int64_t chunk_blocks = chunk / QK8_0;
+        cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<chunk_blocks, 1, 0, stream>>>
+            (cx + offset * sizeof(float), cdst + (offset / QK8_0) * sizeof(block_q8_0), chunk,
+             ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+             ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+             cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int64_t chunk_blocks = chunk;
+        cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<chunk_blocks, 1, 0, stream>>>
+            (cx + (offset / QK8_0) * sizeof(block_q8_0), cdst + offset * sizeof(float), chunk,
+             ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+             ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+             cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     GGML_ASSERT(ne % QK4_0 == 0);
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int64_t chunk_blocks = chunk / QK4_0;
+        cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<chunk_blocks, 1, 0, stream>>>
+            (cx + offset * sizeof(float), cdst + (offset / QK4_0) * sizeof(block_q4_0), chunk,
+             ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+             ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+             cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int64_t chunk_blocks = chunk;
+        cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<chunk_blocks, 1, 0, stream>>>(
+            cx + (offset / QK4_0) * sizeof(block_q4_0), cdst + offset * sizeof(float), chunk,
+            ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+            ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+            cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     GGML_ASSERT(ne % QK4_1 == 0);
+    #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    const int64_t max_chunk = INT_MAX;
+    for (int64_t offset = 0; offset < ne; offset += max_chunk) {
+        const int64_t chunk = (ne - offset) < max_chunk ? (ne - offset) : max_chunk;
+        const int64_t chunk_blocks = chunk / QK4_1;
+        cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<chunk_blocks, 1, 0, stream>>>
+            (cx + offset * sizeof(float), cdst + (offset / QK4_1) * sizeof(block_q4_1), chunk,
+             ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+             ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+             cdst_indirect, graph_cpynode_index++);
+    }
+    #else
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+    #endif
 }
 
 static void ggml_cpy_q4_1_f32_cuda(

From 39fbbb8ae4a11a4836b65b3896a65608776b4611 Mon Sep 17 00:00:00 2001
From: Jesse CreateThis
Date: Wed, 13 Aug 2025 11:19:53 -0400
Subject: [PATCH 3/4] New assertions for GGML_CUDA_ALLOW_LARGE_TENSORS upper bounds, coded by Qwen3-Coder-480B-A35B-Instruct-1M-GGUF

---
 ggml/src/ggml-cuda/cpy.cu | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 1a06f9f77c765..2462f062e7c62 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -1,6 +1,7 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
 #include "cpy-utils.cuh"
+#include <cstdint> // For SIZE_MAX
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
 #include "ggml-musa/mudnn.cuh"
 #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
@@ -363,6 +364,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     #if defined(GGML_CUDA_ALLOW_LARGE_TENSORS)
     // No INT_MAX limit – ggml_nbytes may exceed 2GB on large contexts.
     // The underlying cudaMemcpyAsync can handle size_t lengths.
+    GGML_ASSERT(ggml_nbytes(src0) <= SIZE_MAX / 4); // Reasonable upper bound with safety margin
+    GGML_ASSERT(ggml_nbytes(src1) <= SIZE_MAX / 4); // Reasonable upper bound with safety margin
     #else
     GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
     GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);

From e40e6a6e7d936554601f2cecdd50f4b6d046c6e0 Mon Sep 17 00:00:00 2001
From: Jesse CreateThis
Date: Wed, 13 Aug 2025 17:14:48 +0000
Subject: [PATCH 4/4] Add compile option GGML_CUDA_ALLOW_LARGE_TENSORS and define macro for CUDA large tensor support

This change by gpt-oss-120b-mxfp4.
---
 ggml/CMakeLists.txt               | 1 +
 ggml/src/ggml-cuda/CMakeLists.txt | 4 ++++
 2 files changed, 5 insertions(+)

diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 1fb7abeaf088f..d8bccac70fc24 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -165,6 +165,7 @@ option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copie
 option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
 option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
+option(GGML_CUDA_ALLOW_LARGE_TENSORS "ggml: allow large tensors for CUDA (disable INT_MAX check)" OFF)
 option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
 
 set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING "ggml: cuda link binary compression mode; requires cuda 12.8+")
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 98ed29bc9c12f..ca5b5dd2036c3 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -99,6 +99,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()
 
+    if (GGML_CUDA_ALLOW_LARGE_TENSORS)
+        add_compile_definitions(GGML_CUDA_ALLOW_LARGE_TENSORS)
+    endif()
+
     if (GGML_STATIC)
         if (WIN32)
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
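
Usage sketch (not part of the patch series above; assumes the series applies cleanly to a current llama.cpp/ggml checkout): since PATCH 4/4 wires the option through add_compile_definitions, the new behaviour would be enabled at configure time alongside the existing CUDA switch, roughly:

    cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_ALLOW_LARGE_TENSORS=ON
    cmake --build build --config Release

With the option left at its default (OFF), GGML_CUDA_ALLOW_LARGE_TENSORS is never defined and the original INT_MAX asserts in ggml_cuda_cpy stay in effect.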