Skip to content

Commit 504b3f4

Browse files
committed
Merge remote-tracking branch 'origin/main' into astroC86/get-or-put-to-copy
2 parents 1031a5d + 9e544cb commit 504b3f4

27 files changed

+2030
-171
lines changed

.github/workflows/auto-label.json5

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
labelsSynonyms: {
3+
feature: ["[Feature]"],
4+
bug: ["[Issue]"],
5+
documentation: ["[Documentation]"]
6+
},
7+
defaultLabels: ["iris"], // always add iris
8+
includeTitle: true,
9+
ignoreComments: true
10+
}

.github/workflows/auto-label.yml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
name: Auto labeling
2+
on:
3+
issues:
4+
types: [opened]
5+
pull_request:
6+
types: [opened]
7+
8+
permissions:
9+
contents: read
10+
issues: write
11+
pull-requests: write
12+
13+
jobs:
14+
# Label ISSUES using Renato66/auto-label
15+
label-issues:
16+
if: github.event_name == 'issues'
17+
runs-on: ubuntu-latest
18+
steps:
19+
- uses: actions/checkout@v4
20+
with:
21+
sparse-checkout: |
22+
.github/workflows/auto-label.json5
23+
sparse-checkout-cone-mode: false
24+
- uses: Renato66/auto-label@v3
25+
with:
26+
repo-token: ${{ secrets.GITHUB_TOKEN }}
27+
28+
# Add ISSUES to ROCm Project #91 so they land in Todo
29+
add-issues-to-project:
30+
if: github.event_name == 'issues'
31+
runs-on: ubuntu-latest
32+
steps:
33+
- uses: actions/add-to-project@v1  # NOTE: exact version tag was mangled by the page capture's email obfuscation — verify against the original workflow file
34+
with:
35+
project-url: https://github.com/orgs/ROCm/projects/91
36+
github-token: ${{ secrets.ADD_TO_PROJECT_PAT }}
37+
38+
# PRs: label so the project rule moves them to In Progress
39+
label-prs:
40+
if: github.event_name == 'pull_request'
41+
runs-on: ubuntu-latest
42+
steps:
43+
- name: Add iris + in-progress labels to PR
44+
uses: actions/github-script@v7
45+
with:
46+
script: |
47+
await github.rest.issues.addLabels({
48+
owner: context.repo.owner,
49+
repo: context.repo.repo,
50+
issue_number: context.payload.pull_request.number,
51+
labels: ["iris", "in-progress"]
52+
});

apptainer/iris.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ From: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch
1313
conda install -y -n py_3.10 -c conda-forge mpi4py openmpi jupyter ninja cmake wheel
1414
git clone https://github.com/triton-lang/triton.git \$TRITON_PATH
1515
cd \$TRITON_PATH
16-
git checkout eb73b0373a7fb4cd2e563f68e3488a96525562eb
16+
git checkout dd5823453bcc7973eabadb65f9d827c43281c434
1717
pip install -e .
1818
wget https://github.com/ROCm/rocprofiler-systems/releases/download/rocm-6.3.1/rocprofiler-systems-install.py
1919
python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 6.3

apptainer/jupyter.sh

Lines changed: 0 additions & 22 deletions
This file was deleted.

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ RUN sudo pip3 install mpi4py
3737
# Clone and install Triton
3838
WORKDIR $TRITON_PATH
3939
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
40-
RUN git checkout eb73b0373a7fb4cd2e563f68e3488a96525562eb
40+
RUN git checkout dd5823453bcc7973eabadb65f9d827c43281c434
4141
RUN pip3 install -e .
4242
ENV PYTHONPATH=$TRITON_PATH
4343

examples/07_gemm_all_scatter/gemm_all_scatter.py

Lines changed: 38 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -11,89 +11,6 @@
1111
import iris
1212

1313

14-
@triton.jit
15-
def tile_id_to_index_range(
16-
tile_id,
17-
M,
18-
N,
19-
BLOCK_SIZE_M: tl.constexpr,
20-
BLOCK_SIZE_N: tl.constexpr,
21-
GROUP_SIZE_M: tl.constexpr,
22-
):
23-
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
24-
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
25-
num_pid_in_group = GROUP_SIZE_M * num_pid_n
26-
27-
group_id = tile_id // num_pid_in_group
28-
first_pid_m = group_id * GROUP_SIZE_M
29-
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
30-
31-
tile_in_group = tile_id % num_pid_in_group
32-
pid_m = first_pid_m + (tile_in_group % group_size_m)
33-
pid_n = tile_in_group // group_size_m
34-
35-
rm_start = pid_m * BLOCK_SIZE_M
36-
rn_start = pid_n * BLOCK_SIZE_N
37-
38-
# clamp to the maximum valid index (M-1, N-1)
39-
max_m = M - 1
40-
max_n = N - 1
41-
42-
# generate indices
43-
rm = rm_start + tl.arange(0, BLOCK_SIZE_M)
44-
rn = rn_start + tl.arange(0, BLOCK_SIZE_N)
45-
46-
rm = tl.minimum(rm, max_m)
47-
rn = tl.minimum(rn, max_n)
48-
49-
# rm_mod = rm % M
50-
# rm = tl.max_contiguous(tl.multiple_of(rm_mod, BLOCK_SIZE_M), BLOCK_SIZE_M)
51-
52-
return rm, rn, rm_start, rn_start
53-
54-
55-
@triton.jit
56-
def offset_for_tile(local_tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M_local, N_local):
57-
rm, rn, rm_start, rn_start = tile_id_to_index_range(
58-
local_tile_id, M_local, N_local, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
59-
)
60-
c_mask = (rm[:, None] < M_local) & (rn[None, :] < N_local)
61-
return rm, rn, c_mask, rm_start, rn_start
62-
63-
64-
@triton.jit
65-
def extract_submask_and_offset(
66-
rm,
67-
rn,
68-
mask,
69-
rm_start,
70-
rn_start,
71-
start_row,
72-
start_col,
73-
SUB_BLOCK_SIZE_M: tl.constexpr,
74-
SUB_BLOCK_SIZE_N: tl.constexpr,
75-
BLOCK_SIZE_M: tl.constexpr,
76-
BLOCK_SIZE_N: tl.constexpr,
77-
stride_cm_local: tl.constexpr,
78-
stride_cn_local: tl.constexpr,
79-
):
80-
# Create indices for the sub-block
81-
sub_rm = tl.arange(0, SUB_BLOCK_SIZE_M) + start_row
82-
sub_rn = tl.arange(0, SUB_BLOCK_SIZE_N) + start_col
83-
84-
# Create a 2D grid of indices for the sub-block
85-
sub_rm_2d = sub_rm[:, None] # Shape: (SUB_BLOCK_SIZE_M, 1)
86-
sub_rn_2d = sub_rn[None, :] # Shape: (1, SUB_BLOCK_SIZE_N)
87-
88-
# Compute the sub-mask
89-
sub_mask = (sub_rm_2d < BLOCK_SIZE_M) & (sub_rn_2d < BLOCK_SIZE_N)
90-
91-
# Compute the sub-offset relative to the start of the tile
92-
sub_offset = ((rm_start + sub_rm_2d) * stride_cm_local) + ((rn_start + sub_rn_2d) * stride_cn_local)
93-
94-
return sub_mask, sub_offset
95-
96-
9714
@triton.jit()
9815
def persistent_gemm_all_scatter(
9916
A,
@@ -166,8 +83,8 @@ def persistent_gemm_all_scatter(
16683
A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
16784
B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
16885

169-
tl.assume(pid_m > 0)
170-
tl.assume(pid_n > 0)
86+
tl.assume(pid_m >= 0)
87+
tl.assume(pid_n >= 0)
17188

17289
loop_k = tl.cdiv(K, BLOCK_SIZE_K)
17390
if not EVEN_K:
@@ -195,51 +112,39 @@ def persistent_gemm_all_scatter(
195112
# Accumulator registers with C results
196113
c = acc.to(C.type.element_ty)
197114

198-
rm, rn, mask, rm_start, rn_start = offset_for_tile(tile_id, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M, M, N)
199-
200-
# Calculate the number of sub-tiles in each dimension
201-
num_sub_tiles_m = tl.cdiv(BLOCK_SIZE_M, BLOCK_SIZE_M)
202-
num_sub_tiles_n = tl.cdiv(BLOCK_SIZE_N, BLOCK_SIZE_N)
203-
total_sub_tiles = num_sub_tiles_m * num_sub_tiles_n
204-
205-
for sub_tile_idx in range(0, total_sub_tiles):
206-
# Calculate start_row and start_col for the current sub-tile
207-
start_row = (sub_tile_idx // num_sub_tiles_n) * BLOCK_SIZE_M
208-
start_col = (sub_tile_idx % num_sub_tiles_n) * BLOCK_SIZE_N
209-
210-
# Translate to global
211-
sub_mask, global_offset = extract_submask_and_offset(
212-
rm,
213-
rn + cur_rank * N,
214-
mask,
215-
rm_start,
216-
rn_start + cur_rank * N,
217-
start_row,
218-
start_col,
219-
BLOCK_SIZE_M,
220-
BLOCK_SIZE_N,
221-
BLOCK_SIZE_M,
222-
BLOCK_SIZE_N,
223-
stride_cm_global,
224-
stride_cn_global,
225-
)
226-
227-
# Timestamp for GEMM before store
228-
if COLLECT_TIMESTAMPS:
229-
timestamp = read_realtime()
230-
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
231-
232-
# Store data to the global result using puts
233-
for remote_rank in range(world_size):
234-
if remote_rank == cur_rank:
235-
# For the current rank, we can use store
236-
tl.store(c_global + global_offset, c, mask=sub_mask)
237-
else:
238-
iris.store(
239-
c_global + global_offset,
240-
c,
241-
cur_rank,
242-
remote_rank,
243-
heap_bases,
244-
mask=sub_mask,
245-
)
115+
rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
116+
rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
117+
118+
# Add compiler hints
119+
rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
120+
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
121+
122+
# Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N)
123+
sub_mask = (rm[:, None] < M) & (rn[None, :] < N)
124+
125+
# Calculate the "global" offset of C based on the rank.
126+
# Note how the N-dimension is being multiplied by current rank.
127+
# This is because each rank is computing a portion of the N-dimension
128+
# locally and then scattering it to all other ranks to complete
129+
# the global N-dimension.
130+
global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
131+
132+
# Timestamp for GEMM before store
133+
if COLLECT_TIMESTAMPS:
134+
timestamp = read_realtime()
135+
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)
136+
137+
# Store data to the global result using puts
138+
for remote_rank in range(world_size):
139+
if remote_rank == cur_rank:
140+
# For the current rank, we can use store
141+
tl.store(c_global + global_offset, c, mask=sub_mask)
142+
else:
143+
iris.store(
144+
c_global + global_offset,
145+
c,
146+
cur_rank,
147+
remote_rank,
148+
heap_bases,
149+
mask=sub_mask,
150+
)

0 commit comments

Comments (0)