Skip to content

Commit 1031a5d

Browse files
committed
initial copy impl
1 parent d5a876d commit 1031a5d

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

iris/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
store,
1414
get,
1515
put,
16+
copy,
1617
atomic_add,
1718
atomic_sub,
1819
atomic_cas,
@@ -57,6 +58,7 @@
5758
"store",
5859
"get",
5960
"put",
61+
"copy",
6062
"atomic_add",
6163
"atomic_sub",
6264
"atomic_cas",

iris/iris.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -441,6 +441,13 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
441441
tl.store(translated_to_ptr, data, mask=mask)
442442

443443

444+
@triton.jit
def copy(dst_ptr, src_ptr, from_rank, to_rank, heap_bases, mask=None):
    """
    Copy elements from a translated source pointer into ``dst_ptr``.

    Only ``src_ptr`` is passed through ``__translate`` (with ``from_rank``,
    ``to_rank`` and ``heap_bases``); ``dst_ptr`` is written exactly as given.
    NOTE(review): presumably ``__translate`` rebases an address from
    ``from_rank``'s heap into ``to_rank``'s heap -- confirm against its
    definition.

    Args:
        dst_ptr: Pointer (or block of pointers) that receives the values.
        src_ptr: Pointer (or block of pointers) translated before loading.
        from_rank: Rank forwarded to ``__translate`` as the origin rank.
        to_rank: Rank forwarded to ``__translate`` as the target rank.
        heap_bases: Heap base table forwarded to ``__translate``.
        mask: Optional mask applied to both the load and the store.
    """
    resolved_src = __translate(src_ptr, from_rank, to_rank, heap_bases)
    # The same mask guards both sides so masked-off lanes are never written.
    tl.store(dst_ptr, tl.load(resolved_src, mask=mask), mask=mask)
449+
450+
444451
@triton.jit
445452
def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
446453
"""

tests/unittests/test_copy.py

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,91 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
import pytest
8+
import iris
9+
10+
11+
@triton.jit
def copy_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
):
    """
    Issue one ``iris.copy`` per rank, reading row ``cur_rank`` of ``data``.

    For every ``target_rank`` in ``range(num_ranks)`` the source pointers
    address ``data + BLOCK_SIZE * cur_rank`` and the destination pointers
    address ``results + BLOCK_SIZE * target_rank``; ``iris.copy`` is called
    with ``cur_rank`` as ``from_rank`` and ``target_rank`` as ``to_rank``.
    """
    lane = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Only lanes whose offset is below BLOCK_SIZE take part in the copy.
    in_row = lane < BLOCK_SIZE

    # The source row does not depend on the loop variable.
    src_ptrs = data + BLOCK_SIZE * cur_rank + lane

    for target_rank in range(num_ranks):
        dst_ptrs = results + BLOCK_SIZE * target_rank + lane
        iris.copy(dst_ptrs, src_ptrs, cur_rank, target_rank, heap_bases, in_row)
36+
37+
38+
@pytest.mark.parametrize(
    "dtype",
    [
        torch.int8,
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ],
)
@pytest.mark.parametrize(
    "BLOCK_SIZE",
    [
        1,
        8,
        16,
        32,
    ],
)
def test_copy_get_semantics(dtype, BLOCK_SIZE):
    """
    Check that ``iris.copy`` pulls row ``cur_rank`` from every rank's ``data``.

    Each rank fills ``data[i, :]`` with ``(rank + num_ranks) * (i + 1)``.
    ``copy_kernel`` then stores, into ``results[target_rank, :]``, the values
    loaded through ``iris.copy`` with ``from_rank=cur_rank`` and
    ``to_rank=target_rank``, so the expected content of row ``rank_id`` is
    ``(rank_id + num_ranks) * (cur_rank + 1)``.
    """
    shmem = iris.iris(1 << 20)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
    base = cur_rank + num_ranks
    for i in range(num_ranks):
        data[i, :] = base * (i + 1)

    results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)

    # Every rank must have finished filling `data` before any rank's kernel
    # reads it through a translated pointer; otherwise stale zeros may be
    # observed.
    shmem.barrier()

    grid = lambda meta: (1,)
    copy_kernel[grid](
        data,
        results,
        cur_rank,
        num_ranks,
        BLOCK_SIZE,
        heap_bases,
    )
    shmem.barrier()

    # Build the reference on the same device as `results` (not on the
    # symmetric heap) so assert_close's device check cannot fail spuriously.
    expected = torch.zeros(
        (num_ranks, BLOCK_SIZE), dtype=dtype, device=results.device
    )
    for rank_id in range(num_ranks):
        expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)

    try:
        torch.testing.assert_close(results, expected, rtol=0, atol=0)
    except AssertionError as e:
        # Dump both tensors to make multi-rank failures easier to diagnose.
        print(e)
        print("Expected:", expected)
        print("Actual:", results)
        raise

0 commit comments

Comments
 (0)