1
+ # SPDX-License-Identifier: MIT
2
+ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
3
+
4
+ import torch
5
+ import triton
6
+ import triton .language as tl
7
+ import pytest
8
+ import iris
9
+
10
+
11
+ @triton .jit
12
+ def copy_kernel (
13
+ data ,
14
+ results ,
15
+ cur_rank : tl .constexpr ,
16
+ num_ranks : tl .constexpr ,
17
+ BLOCK_SIZE : tl .constexpr ,
18
+ heap_bases : tl .tensor ,
19
+ ):
20
+ pid = tl .program_id (0 )
21
+ block_start = pid * BLOCK_SIZE
22
+ offsets = block_start + tl .arange (0 , BLOCK_SIZE )
23
+ mask = offsets < BLOCK_SIZE
24
+
25
+ for target_rank in range (num_ranks ):
26
+ src_data = data + BLOCK_SIZE * cur_rank
27
+ dest_data = results + BLOCK_SIZE * target_rank
28
+ iris .copy (
29
+ dest_data + offsets ,
30
+ src_data + offsets ,
31
+ cur_rank ,
32
+ target_rank ,
33
+ heap_bases ,
34
+ mask
35
+ )
36
+
37
+
38
+ @pytest .mark .parametrize (
39
+ "dtype" ,
40
+ [
41
+ torch .int8 ,
42
+ torch .float16 ,
43
+ torch .bfloat16 ,
44
+ torch .float32 ,
45
+ ],
46
+ )
47
+ @pytest .mark .parametrize (
48
+ "BLOCK_SIZE" ,
49
+ [
50
+ 1 ,
51
+ 8 ,
52
+ 16 ,
53
+ 32 ,
54
+ ],
55
+ )
56
+ def test_copy_get_semantics (dtype , BLOCK_SIZE ):
57
+ shmem = iris .iris (1 << 20 )
58
+ num_ranks = shmem .get_num_ranks ()
59
+ heap_bases = shmem .get_heap_bases ()
60
+ cur_rank = shmem .get_rank ()
61
+
62
+ data = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
63
+ base = cur_rank + num_ranks
64
+ for i in range (num_ranks ):
65
+ data [i , :] = base * (i + 1 )
66
+
67
+ results = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
68
+ grid = lambda meta : (1 ,)
69
+ copy_kernel [grid ](
70
+ data ,
71
+ results ,
72
+ cur_rank ,
73
+ num_ranks ,
74
+ BLOCK_SIZE ,
75
+ heap_bases
76
+ )
77
+ shmem .barrier ()
78
+
79
+ expected = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
80
+ expected_2 = torch .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype , device = "cuda" )
81
+ for rank_id in range (num_ranks ):
82
+ expected [rank_id , :] = 999999
83
+ expected_2 [rank_id , :] = (rank_id + num_ranks ) * (cur_rank + 1 )
84
+
85
+ try :
86
+ torch .testing .assert_close (results , expected , rtol = 0 , atol = 0 )
87
+ except AssertionError as e :
88
+ print (e )
89
+ print ("Expected:" , expected_2 )
90
+ print ("Actual:" , results )
91
+ raise
0 commit comments