@@ -19,18 +19,14 @@ def load_kernel(
1919):
2020 pid = tl .program_id (0 )
2121
22+ partner = int ((source_rank + num_ranks // 2 ) % num_ranks )
2223 # Compute start index of this block
2324 block_start = pid * BLOCK_SIZE
2425 offsets = block_start + tl .arange (0 , BLOCK_SIZE )
2526
2627 # Guard for out-of-bounds accesses
2728 mask = offsets < BLOCK_SIZE
28-
29- result = tl .zeros ([BLOCK_SIZE ], dtype = data .type .element_ty )
30- for target_rank in range (num_ranks ):
31- result += iris .load (data + offsets , source_rank , target_rank , heap_bases , mask = mask )
32-
33- # Store data to result buffer
29+ result = iris .load (data + offsets , source_rank , partner , heap_bases , mask = mask )
3430 tl .store (results + offsets , result , mask = mask )
3531
3632
@@ -58,16 +54,17 @@ def test_load_api(dtype, BLOCK_SIZE):
5854 num_ranks = shmem .get_num_ranks ()
5955 heap_bases = shmem .get_heap_bases ()
6056 source_rank = shmem .get_rank ()
57+ partner = int ((source_rank + num_ranks // 2 ) % num_ranks )
6158
62- data = shmem .ones ( BLOCK_SIZE , dtype = dtype )
59+ data = shmem .full (( BLOCK_SIZE ,), source_rank , dtype = dtype )
6360 results = shmem .zeros_like (data )
6461
6562 grid = lambda meta : (1 ,)
6663 load_kernel [grid ](data , results , source_rank , num_ranks , BLOCK_SIZE , heap_bases )
6764 shmem .barrier ()
6865
6966 # Verify the result
70- expected = torch .ones (BLOCK_SIZE , dtype = dtype , device = "cuda" ) * num_ranks
67+ expected = torch .ones (BLOCK_SIZE , dtype = dtype , device = "cuda" ) * partner
7168
7269 try :
7370 torch .testing .assert_close (results , expected , rtol = 0 , atol = 0 )
0 commit comments