
Commit d535696

Fix gemm example on CUDA 13.0.
`#[address_space(shared)] static mut` is used to model GPU shared memory. It's a bit weird. In particular, GPU shared memory is uninitialized, but `static mut` requires an initializer in Rust. `gemm` uses a zero initializer, but this initializer is ignored by NVVM. At least, it was in CUDA 12.x, but in CUDA 13.0 the `gemm` example fails with this error:

```
thread 'rustc' panicked at crates/rustc_codegen_nvvm/src/nvvm.rs:120:9:
Malformed NVVM IR program rejected by libnvvm, dumping verifier log:
error: Error: : Global Variable `_ZN12gemm_kernels10gemm_tiled10gemm_tiled6TILE_A17hc9c66e758c373a7eE': context: @_ZN12gemm_kernels10gemm_tiled10gemm_tiled6TILE_A17hc9c66e758c373a7eE = internal unnamed_addr addrspace(3) global <{ [1024 x i8] }> zeroinitializer, align 4
Shared variables can't be initialized
```

This memory looks like it's initialized to zero but isn't, and then is written and read normally. This is incredibly dodgy and very likely UB. The proper way to deal with uninitialized memory in Rust is with `MaybeUninit`, and there are strict rules around its use, e.g. writes must be done with `write`, and `assume_init` must be used to read values only after they have been written. This commit changes `gemm` to use `MaybeUninit` for the shared memory. This fixes the error on CUDA 13.0 and the example runs correctly.

(This is the only executed use of GPU shared memory in rust-cuda. There is a `shared_array!` macro defined but it's only used in a compiletest where it is compiled but not run. That macro is extremely dubious but I will deal with it in a separate PR because it's not necessary to get CUDA 13.0 working.)
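To make the `MaybeUninit` rules the message refers to concrete, here is a small CPU-side illustration using only `core`. It is an aside for readers, not code from this commit, and the buffer size and values are made up:

```rust
use core::mem::MaybeUninit;

fn main() {
    const N: usize = 4;

    // Storage starts uninitialized; reading it at this point would be UB.
    let mut buf: [MaybeUninit<f32>; N] = [const { MaybeUninit::uninit() }; N];

    // Writes must go through `write`, which initializes a slot without
    // reading or dropping the old (garbage) bytes.
    for (i, slot) in buf.iter_mut().enumerate() {
        slot.write(i as f32 * 0.5);
    }

    // Only after a value has been written is `assume_init` sound.
    let sum: f32 = buf.iter().map(|slot| unsafe { slot.assume_init() }).sum();
    assert_eq!(sum, 3.0);
}
```

The shared-memory change in the diff below follows exactly this write-then-`assume_init` discipline, with `thread::sync_threads()` providing the ordering between the writing and reading threads.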
1 parent b0c4c71 commit d535696

File tree

2 files changed: +19 -7 lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions

examples/cuda/gemm/kernels/src/gemm_tiled.rs

Lines changed: 18 additions & 7 deletions
```diff
@@ -1,3 +1,4 @@
+use core::mem::MaybeUninit;
 use cuda_std::address_space;
 use cuda_std::kernel;
 use cuda_std::thread;
@@ -38,11 +39,19 @@ pub unsafe fn gemm_tiled(
     beta: f32,
 ) {
     const TILE_SIZE: usize = 16;
+    const TILE_SIZE_2D: usize = TILE_SIZE * TILE_SIZE;
 
+    // Shared GPU memory is modelled with `#[address_space(shared)] static mut`. Unlike normal
+    // `static mut`, it is not initialized, and only exists for the duration of the kernel's
+    // (multi-)execution. Because it is not initialized, it must be marked with `MaybeUninit`,
+    // written with `write` (in unsafe blocks because writing a `static mut` is unsafe), and
+    // subsequently read with `assume_init`.
     #[address_space(shared)]
-    static mut TILE_A: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
+    static mut TILE_A: [MaybeUninit<f32>; TILE_SIZE_2D] =
+        [const { MaybeUninit::uninit() }; TILE_SIZE_2D];
     #[address_space(shared)]
-    static mut TILE_B: [f32; TILE_SIZE * TILE_SIZE] = [0.; TILE_SIZE * TILE_SIZE];
+    static mut TILE_B: [MaybeUninit<f32>; TILE_SIZE_2D] =
+        [const { MaybeUninit::uninit() }; TILE_SIZE_2D];
 
     // Thread indices within the block.
     let tx = thread::thread_idx_x() as usize;
@@ -57,20 +66,22 @@ pub unsafe fn gemm_tiled(
     for kk in (0..k).step_by(TILE_SIZE) {
         // Collaborative loading of tiles into shared memory.
         if row < m && (kk + tx) < k {
-            unsafe { TILE_A[ty * TILE_SIZE + tx] = mat_a[row * k + (kk + tx)] };
+            unsafe { TILE_A[ty * TILE_SIZE + tx].write(mat_a[row * k + (kk + tx)]); }
         } else {
-            unsafe { TILE_A[ty * TILE_SIZE + tx] = 0.0f32 };
+            unsafe { TILE_A[ty * TILE_SIZE + tx].write(0.0f32); }
         }
         if col < n && (kk + ty) < k {
-            unsafe { TILE_B[ty * TILE_SIZE + tx] = mat_b[(kk + ty) * n + col] };
+            unsafe { TILE_B[ty * TILE_SIZE + tx].write(mat_b[(kk + ty) * n + col]); }
         } else {
-            unsafe { TILE_B[ty * TILE_SIZE + tx] = 0.0f32 };
+            unsafe { TILE_B[ty * TILE_SIZE + tx].write(0.0f32); }
         }
         thread::sync_threads();
 
         // Perform the computation on the tile.
         for i in 0..TILE_SIZE {
-            sum += unsafe { TILE_A[ty * TILE_SIZE + i] * TILE_B[i * TILE_SIZE + tx] };
+            sum += unsafe {
+                TILE_A[ty * TILE_SIZE + i].assume_init() * TILE_B[i * TILE_SIZE + tx].assume_init()
+            };
         }
         thread::sync_threads();
     }
```
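Distilled, the shared-memory pattern the diff lands on looks like this in isolation. This is a hedged sketch: the kernel below is a made-up block-reversal example, not code from the commit, and it assumes the same `cuda_std` items the gemm kernel uses above (`kernel`, `address_space`, `thread`):

```rust
use core::mem::MaybeUninit;

use cuda_std::address_space;
use cuda_std::kernel;
use cuda_std::thread;

// Illustrative kernel only: reverses up to one block's worth of `data` in place.
#[kernel]
pub unsafe fn reverse_block(data: *mut f32, len: usize) {
    const BLOCK: usize = 256;

    // 1. Declare shared memory as uninitialized: no dummy zero initializer.
    #[address_space(shared)]
    static mut SCRATCH: [MaybeUninit<f32>; BLOCK] = [const { MaybeUninit::uninit() }; BLOCK];

    let i = thread::thread_idx_x() as usize;
    let n = len.min(BLOCK);

    // 2. Initialize through `MaybeUninit::write`, never plain assignment.
    if i < n {
        unsafe { SCRATCH[i].write(*data.add(i)); }
    }
    thread::sync_threads();

    // 3. Read back with `assume_init`, only for elements known to be written.
    if i < n {
        unsafe { *data.add(i) = SCRATCH[n - 1 - i].assume_init(); }
    }
}
```

As in the gemm change, the `sync_threads()` barrier is what lets one thread `assume_init` a slot that a different thread wrote.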
