Skip to content

Commit fcc41e8

Browse files
committed
remoted trap flag
1 parent 877df70 commit fcc41e8

File tree

2 files changed

+8
-24
lines changed

2 files changed

+8
-24
lines changed

iris/iris.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
>>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32)
2323
"""
2424

25-
from iris.util import trap_if
2625
import triton
2726
import triton.language as tl
2827

@@ -1540,6 +1539,7 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
15401539
space to the to_rank's address space, performing a masked load from the translated
15411540
source, and storing the loaded data to dst_ptr in the to_rank memory location.
15421541
If from_rank and to_rank are the same, this function performs a local copy operation.
1542+
It is undefined behaviour if neither from_rank nor to_rank is the cur_rank.
15431543
15441544
Args:
15451545
src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's local memory from which to read data.
@@ -1551,9 +1551,14 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None):
15511551
15521552
Returns:
15531553
None
1554-
"""
15551554
1556-
trap_if((cur_rank != from_rank) and (cur_rank != to_rank))
1555+
Example:
1556+
>>> @triton.jit
1557+
>>> def kernel(remote_ptr, local_ptr, heap_bases):
1558+
>>> from_rank = 1
1559+
>>> to_rank = 0
1560+
>>> iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases)
1561+
"""
15571562

15581563
cur_base = tl.load(heap_bases + cur_rank)
15591564

iris/util.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,3 @@ def do_bench(
142142

143143
times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
144144
return _summarize_statistics(times, quantiles, return_mode)
145-
146-
147-
@triton.jit
148-
def trap_if(cond):
149-
drv = tl.zeros([1], dtype=tl.uint32)
150-
cond_u32 = tl.where(cond, drv, drv + 1)
151-
if tl.program_id(0) == 0:
152-
tl.inline_asm_elementwise(
153-
asm="""
154-
s_cmp_lg_u32 $1, 0
155-
s_cbranch_scc1 0f
156-
s_trap 2
157-
0:
158-
""",
159-
constraints="=v,s",
160-
args=[cond_u32],
161-
dtype=tl.uint32,
162-
is_pure=False,
163-
pack=1,
164-
)
165-

0 commit comments

Comments
 (0)