
Commit 53f5eb7

Apply Ruff auto-fixes
Parent: 378a3fc

File tree: 2 files changed (+39, -29 lines)

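The hunks below are mechanical formatting fixes of the kind Ruff applies automatically: over-long calls are wrapped, short ones are joined, and blank-line and whitespace rules are enforced. No statement changes behaviour. A commit like this is typically produced by running Ruff's fixer and formatter over the tree, for example (a hypothetical invocation; the exact command is not recorded in the commit):

    ruff check --fix iris/ tests/
    ruff format iris/ tests/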

iris/cuda.py

Lines changed: 4 additions & 4 deletions
@@ -82,7 +82,9 @@ def get_cu_count(device_id=None):
     cudaDeviceAttributeMultiprocessorCount = 16
     cu_count = ctypes.c_int()

-    cuda_try(cuda_runtime.cudaDeviceGetAttribute(ctypes.byref(cu_count), cudaDeviceAttributeMultiprocessorCount, device_id))
+    cuda_try(
+        cuda_runtime.cudaDeviceGetAttribute(ctypes.byref(cu_count), cudaDeviceAttributeMultiprocessorCount, device_id)
+    )

     return cu_count.value

@@ -107,9 +109,7 @@ def get_cu_count(device_id=None):
 def get_wall_clock_rate(device_id):
     cudaDevAttrMemoryClockRate = 36
     wall_clock_rate = ctypes.c_int()
-    status = cuda_runtime.cudaDeviceGetAttribute(
-        ctypes.byref(wall_clock_rate), cudaDevAttrMemoryClockRate, device_id
-    )
+    status = cuda_runtime.cudaDeviceGetAttribute(ctypes.byref(wall_clock_rate), cudaDevAttrMemoryClockRate, device_id)
     cuda_try(status)
     return wall_clock_rate.value
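For readers skimming the diff, the code being rewrapped queries a device attribute through the CUDA runtime's C API via ctypes and funnels the returned status code through an error-checking helper. A minimal self-contained sketch of that pattern, assuming a loadable libcudart; cuda_runtime and cuda_try here are stand-ins for iris's own helpers, which this diff does not show:

    import ctypes

    # Hypothetical direct load; iris locates and loads the runtime library itself.
    cuda_runtime = ctypes.CDLL("libcudart.so")

    def cuda_try(status):
        # cudaSuccess is 0; any non-zero status code is an error.
        if status != 0:
            raise RuntimeError(f"CUDA runtime call failed with status {status}")

    def get_cu_count(device_id=0):
        # 16 is cudaDevAttrMultiProcessorCount, the same constant the diff uses.
        cudaDeviceAttributeMultiprocessorCount = 16
        cu_count = ctypes.c_int()
        cuda_try(
            cuda_runtime.cudaDeviceGetAttribute(
                ctypes.byref(cu_count), cudaDeviceAttributeMultiprocessorCount, device_id
            )
        )
        return cu_count.value

    print(get_cu_count())  # e.g. the number of SMs / compute units on device 0

Splitting the cuda_try(...) call across lines, as the first hunk does, keeps each line under the formatter's length limit without changing the call.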

tests/examples/test_load_latency.py

Lines changed: 35 additions & 25 deletions
@@ -11,6 +11,7 @@
 from mpi4py import MPI
 # from examples.common.utils import read_realtime

+
 @triton.jit
 def read_realtime():
     tmp = tl.inline_asm_elementwise(
@@ -23,21 +24,25 @@ def read_realtime():
     )
     return tmp

+
 @triton.jit()
 def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
+    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
 ):
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)

     latency_mask = offsets < num_ranks
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
+    iris.put(
+        local_latency + offsets,
+        global_latency + curr_rank * num_ranks + offsets,
+        curr_rank,
+        0,
+        heap_bases,
+        mask=latency_mask,
+    )
+

 @triton.jit()
 def ping_pong(
@@ -100,9 +105,9 @@ def ping_pong(
 #     ],
 # )

-#def test_load_bench(dtype, heap_size):
+# def test_load_bench(dtype, heap_size):
 if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
     heap_size = 1 << 32
     shmem = iris.iris(heap_size)
     num_ranks = shmem.get_num_ranks()
@@ -115,36 +120,42 @@ def ping_pong(
     iter = 1
     skip = 1
     mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-
-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
-    latency_matrix = shmem.zeros((num_ranks, num_ranks), dtype=torch.float32,device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    latency_matrix = shmem.zeros((num_ranks, num_ranks), dtype=torch.float32, device="cuda")

     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)

     grid = lambda meta: (1,)
     for source_rank in range(num_ranks):
         for destination_rank in range(num_ranks):
             if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                 print(source_rank, destination_rank)
-                ping_pong[grid](source_buffer,
-                    result_buffer, BUFFER_LEN,
-                    skip, iter,
-                    flag,
-                    source_rank, destination_rank,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    result_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    source_rank,
+                    destination_rank,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
         shmem.barrier()
         torch.cuda.synchronize()
         MPI.COMM_WORLD.Barrier()

     for destination_rank in range(num_ranks):
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter

     # gather_latencies[grid](local_latency, latency_matrix, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
     # shmem.barrier()
@@ -160,7 +171,6 @@ def ping_pong(
     # line = f"R{i}\t" + "\t".join(row_entries) + "\n"
     # f.write(line)

-
     # if cur_rank == 0:
     #     print("\nLatency measurements (raw timer ticks and per-iteration average):")
     #     # for i in range(num_ranks):
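All of the hunks in this file are consistent with standard Ruff behaviour: two blank lines are inserted before top-level @triton.jit definitions, long call expressions are exploded to one argument per line with a trailing comma while short parameter lists are collapsed onto one line, and the few -/+ pairs whose visible text is identical (dtype = torch.int32, flag = shmem.ones(1, dtype=dtype)) are most plausibly trailing-whitespace or indentation-only fixes that this rendering cannot display. The split/join threshold suggests a line-length limit of around 120 characters, though the project's actual Ruff configuration is not part of this commit.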
