Skip to content

Commit 894410f

Browse files
Apply Ruff auto-fixes
1 parent c2ca89c commit 894410f

File tree

2 files changed

+9
-23
lines changed

2 files changed

+9
-23
lines changed

iris/iris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
407407
Returns:
408408
None
409409
"""
410-
copy(from_ptr, to_ptr, from_rank, to_rank , heap_bases, mask)
410+
copy(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask)
411411

412412

413413
@triton.jit

tests/unittests/test_copy.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,11 @@ def copy_kernel(
2121
block_start = pid * BLOCK_SIZE
2222
offsets = block_start + tl.arange(0, BLOCK_SIZE)
2323
mask = offsets < BLOCK_SIZE
24-
24+
2525
for target_rank in range(num_ranks):
26-
src_data = data + BLOCK_SIZE * cur_rank
26+
src_data = data + BLOCK_SIZE * cur_rank
2727
dest_data = results + BLOCK_SIZE * target_rank
28-
iris.copy(
29-
src_data + offsets,
30-
dest_data + offsets,
31-
cur_rank,
32-
target_rank,
33-
heap_bases,
34-
mask
35-
)
28+
iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, heap_bases, mask)
3629

3730

3831
@pytest.mark.parametrize(
@@ -53,7 +46,7 @@ def copy_kernel(
5346
32,
5447
],
5548
)
56-
def test_copy_get_semantics(dtype, BLOCK_SIZE):
49+
def test_copy_get_semantics(dtype, BLOCK_SIZE):
5750
shmem = iris.iris(1 << 20)
5851
num_ranks = shmem.get_num_ranks()
5952
heap_bases = shmem.get_heap_bases()
@@ -66,24 +59,17 @@ def test_copy_get_semantics(dtype, BLOCK_SIZE):
6659

6760
results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
6861
grid = lambda meta: (1,)
69-
copy_kernel[grid](
70-
data,
71-
results,
72-
cur_rank,
73-
num_ranks,
74-
BLOCK_SIZE,
75-
heap_bases
76-
)
62+
copy_kernel[grid](data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
7763
shmem.barrier()
7864

79-
expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
65+
expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
8066
for rank_id in range(num_ranks):
8167
expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)
82-
68+
8369
try:
8470
torch.testing.assert_close(results, expected, rtol=0, atol=0)
8571
except AssertionError as e:
8672
print(e)
8773
print("Expected:", expected)
8874
print("Actual:", results)
89-
raise
75+
raise

0 commit comments

Comments
 (0)