@@ -21,18 +21,11 @@ def copy_kernel(
21
21
block_start = pid * BLOCK_SIZE
22
22
offsets = block_start + tl .arange (0 , BLOCK_SIZE )
23
23
mask = offsets < BLOCK_SIZE
24
-
24
+
25
25
for target_rank in range (num_ranks ):
26
- src_data = data + BLOCK_SIZE * cur_rank
26
+ src_data = data + BLOCK_SIZE * cur_rank
27
27
dest_data = results + BLOCK_SIZE * target_rank
28
- iris .copy (
29
- src_data + offsets ,
30
- dest_data + offsets ,
31
- cur_rank ,
32
- target_rank ,
33
- heap_bases ,
34
- mask
35
- )
28
+ iris .copy (src_data + offsets , dest_data + offsets , cur_rank , target_rank , heap_bases , mask )
36
29
37
30
38
31
@pytest .mark .parametrize (
@@ -53,7 +46,7 @@ def copy_kernel(
53
46
32 ,
54
47
],
55
48
)
56
- def test_copy_get_semantics (dtype , BLOCK_SIZE ):
49
+ def test_copy_get_semantics (dtype , BLOCK_SIZE ):
57
50
shmem = iris .iris (1 << 20 )
58
51
num_ranks = shmem .get_num_ranks ()
59
52
heap_bases = shmem .get_heap_bases ()
@@ -66,24 +59,17 @@ def test_copy_get_semantics(dtype, BLOCK_SIZE):
66
59
67
60
results = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
68
61
grid = lambda meta : (1 ,)
69
- copy_kernel [grid ](
70
- data ,
71
- results ,
72
- cur_rank ,
73
- num_ranks ,
74
- BLOCK_SIZE ,
75
- heap_bases
76
- )
62
+ copy_kernel [grid ](data , results , cur_rank , num_ranks , BLOCK_SIZE , heap_bases )
77
63
shmem .barrier ()
78
64
79
- expected = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
65
+ expected = shmem .zeros ((num_ranks , BLOCK_SIZE ), dtype = dtype )
80
66
for rank_id in range (num_ranks ):
81
67
expected [rank_id , :] = (rank_id + num_ranks ) * (cur_rank + 1 )
82
-
68
+
83
69
try :
84
70
torch .testing .assert_close (results , expected , rtol = 0 , atol = 0 )
85
71
except AssertionError as e :
86
72
print (e )
87
73
print ("Expected:" , expected )
88
74
print ("Actual:" , results )
89
- raise
75
+ raise
0 commit comments