Skip to content

Commit 877df70

Browse files
committed
Feat: Impl copy
1 parent 3c54277 commit 877df70

File tree

6 files changed

+78
-7
lines changed

6 files changed

+78
-7
lines changed

iris/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
load,
3434
store,
3535
copy,
36+
get,
37+
put,
3638
atomic_add,
3739
atomic_cas,
3840
atomic_xchg,
@@ -86,6 +88,8 @@
8688
"load",
8789
"store",
8890
"copy",
91+
"get",
92+
"put",
8993
"atomic_add",
9094
"atomic_cas",
9195
"atomic_xchg",
@@ -102,4 +106,4 @@
102106
"INFO",
103107
"WARNING",
104108
"ERROR",
105-
]
109+
]

iris/iris.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
>>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32)
2323
"""
2424

25+
from iris.util import trap_if
2526
import triton
2627
import triton.language as tl
2728

@@ -1532,7 +1533,7 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
15321533

15331534

15341535
@triton.jit
1535-
def copy(src_ptr, dst_ptr, from_rank, to_rank, heap_bases, mask=None):
1536+
def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
15361537
"""
15371538
Copies data from the specified rank's memory into the destination rank's memory.
15381539
This function performs the transfer by translating src_ptr from the from_rank's address
@@ -1548,6 +1549,51 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, heap_bases, mask=None):
15481549
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
15491550
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.
15501551
1552+
Returns:
1553+
None
1554+
"""
1555+
1556+
trap_if((cur_rank != from_rank) and (cur_rank != to_rank))
1557+
1558+
cur_base = tl.load(heap_bases + cur_rank)
1559+
1560+
from_base = tl.load(heap_bases + from_rank)
1561+
to_base = tl.load(heap_bases + to_rank)
1562+
1563+
src_ptr_int = tl.cast(src_ptr, tl.uint64)
1564+
src_offset = src_ptr_int - cur_base
1565+
1566+
dst_ptr_int = tl.cast(dst_ptr, tl.uint64)
1567+
dst_offset = dst_ptr_int - cur_base
1568+
1569+
from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8))
1570+
to_base_byte = tl.cast(to_base , tl.pointer_type(tl.int8))
1571+
1572+
translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype)
1573+
translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype)
1574+
1575+
data = tl.load(translated_src, mask=mask)
1576+
tl.store(translated_dst, data, mask=mask)
1577+
1578+
1579+
@triton.jit
1580+
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
1581+
"""
1582+
Copies data from the specified rank's memory to the current rank's local memory.
1583+
1584+
This function performs a memory read operation by translating the from_ptr
1585+
from the current rank's address space to the from_rank's address space, loading data
1586+
from the from_rank memory location, and storing it to the local to_ptr.
1587+
If the from_rank is the same as the current rank, this function performs a local copy operation.
1588+
1589+
Args:
1590+
from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the from_rank's address space. The pointer must be local to the current rank.
1591+
to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory where the data will be stored.
1592+
from_rank (int): The ID of the rank from which to read the data.
1593+
to_rank (int): The current rank ID where the data will be stored.
1594+
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
1595+
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.
1596+
15511597
Returns:
15521598
None
15531599

iris/util.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,24 @@ def do_bench(
142142

143143
times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
144144
return _summarize_statistics(times, quantiles, return_mode)
145+
146+
147+
@triton.jit
148+
def trap_if(cond):
149+
drv = tl.zeros([1], dtype=tl.uint32)
150+
cond_u32 = tl.where(cond, drv, drv + 1)
151+
if tl.program_id(0) == 0:
152+
tl.inline_asm_elementwise(
153+
asm="""
154+
s_cmp_lg_u32 $1, 0
155+
s_cbranch_scc1 0f
156+
s_trap 2
157+
0:
158+
""",
159+
constraints="=v,s",
160+
args=[cond_u32],
161+
dtype=tl.uint32,
162+
is_pure=False,
163+
pack=1,
164+
)
165+

tests/unittests/test_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def copy_kernel(
2525
for target_rank in range(num_ranks):
2626
src_data = data + BLOCK_SIZE * cur_rank
2727
dest_data = results + BLOCK_SIZE * target_rank
28-
iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, heap_bases, mask)
28+
iris.copy(src_data + offsets, dest_data + offsets, target_rank, cur_rank, cur_rank, heap_bases, mask)
2929

3030

3131
@pytest.mark.parametrize(

tests/unittests/test_get.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_kernel(
3131
# Loop over all ranks, get the stored data.
3232
# load to local register, accumulate.
3333
for target_rank in range(num_ranks):
34-
iris.copy(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask)
34+
iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask)
3535
acc += tl.load(results + offsets, mask=mask)
3636

3737
# Store the accumulated value back to the output.
@@ -81,4 +81,4 @@ def test_get_api(dtype, BLOCK_SIZE):
8181
print(e)
8282
print("Expected:", expected)
8383
print("Actual:", results)
84-
raise
84+
raise

tests/unittests/test_put.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def put_kernel(
2929
# Put data in all ranks
3030
# It doesn't matter which rank stores last; the data should all be the same at the end.
3131
for target_rank in range(num_ranks):
32-
iris.copy(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask)
32+
iris.put(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask)
3333

3434

3535
@pytest.mark.parametrize(
@@ -75,4 +75,4 @@ def test_put_api(dtype, BLOCK_SIZE):
7575
print(e)
7676
print("Expected:", expected)
7777
print("Actual:", results)
78-
raise
78+
raise

0 commit comments

Comments
 (0)