from mpi4py import MPI
# from examples.common.utils import read_realtime

+
@triton.jit
def read_realtime():
    tmp = tl.inline_asm_elementwise(
@@ -23,21 +24,25 @@ def read_realtime():
    )
    return tmp

+
@triton.jit()
def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
+    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
+    iris.put(
+        local_latency + offsets,
+        global_latency + curr_rank * num_ranks + offsets,
+        curr_rank,
+        0,
+        heap_bases,
+        mask=latency_mask,
+    )
+

@triton.jit()
def ping_pong(
@@ -100,9 +105,9 @@ def ping_pong(
# ],
# )

-#def test_load_bench(dtype, heap_size):
+# def test_load_bench(dtype, heap_size):
if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
@@ -115,36 +120,42 @@ def ping_pong(
    iter = 1
    skip = 1
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-
-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
-    latency_matrix = shmem.zeros((num_ranks, num_ranks), dtype=torch.float32,device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    latency_matrix = shmem.zeros((num_ranks, num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                print(source_rank, destination_rank)
-                ping_pong[grid](source_buffer,
-                    result_buffer, BUFFER_LEN,
-                    skip, iter,
-                    flag,
-                    source_rank, destination_rank,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    result_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    source_rank,
+                    destination_rank,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
                shmem.barrier()
                torch.cuda.synchronize()
                MPI.COMM_WORLD.Barrier()

    for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter

    # gather_latencies[grid](local_latency, latency_matrix, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
    # shmem.barrier()
@@ -160,7 +171,6 @@ def ping_pong(
# line = f"R{i}\t" + "\t".join(row_entries) + "\n"
# f.write(line)

-
# if cur_rank == 0:
# print("\nLatency measurements (raw timer ticks and per-iteration average):")
# # for i in range(num_ranks):
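
Usage note (not part of the diff above): the timestamps recorded by read_realtime() are raw GPU wall-clock ticks, so the per-iteration values stored in local_latency only become a time once they are divided by the counter frequency. Below is a minimal post-processing sketch; the 100 MHz value and the ticks_to_us helper are assumptions for illustration, not values taken from this PR, and must be checked against the actual hardware.

# A minimal sketch for post-processing the benchmark output, assuming the
# realtime counter ticks at 100 MHz (hardware-dependent; verify before use).
import torch

WALL_CLOCK_HZ = 100e6  # assumed counter frequency, not taken from this PR


def ticks_to_us(ticks: torch.Tensor) -> torch.Tensor:
    # Convert raw per-iteration tick counts into microseconds.
    return ticks.double() / WALL_CLOCK_HZ * 1e6


# Example: print(ticks_to_us(local_latency.cpu())) after the measurement loop.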