Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 112 additions & 76 deletions examples/09_gemm_one_shot_all_reduce/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,106 @@ def parse_args():
return vars(parser.parse_args())


def gemm_one_shot_all_reduce(A, B, shmem, args_dict):
    """
    Core GEMM one-shot all-reduce function that can be reused by both example and tests.

    Each rank multiplies its K-slice of A with its K-slice of B (a column slice
    of A against a row slice of B), producing a partial (m, n) product; the
    `matmul.apply` kernel then combines the partial products from all ranks
    into ``global_C`` (the "one-shot all-reduce").

    Args:
        A: Input matrix A; sliced along dim 1 into per-rank K-chunks.
        B: Input matrix B; sliced along dim 0 into per-rank K-chunks.
        shmem: Iris shared memory object providing rank info, symmetric-heap
            tensor allocation, and barriers.
        args_dict: Dictionary containing algorithm parameters ("m", "n", "k",
            block sizes, SM counts, kernel tuning knobs; "trace_tiles" is
            optional and defaults to False).

    Returns:
        global_C: The (m, n) result matrix after GEMM and all-reduce.

    Raises:
        ValueError: If "n" or "k" is not divisible by the world size, or if
            "gemm_sms" is not strictly less than "total_sms".
    """
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    cu_count = shmem.get_cu_count()

    # Validate divisibility requirements.
    # Raise rather than `assert` so validation survives `python -O`
    # (asserts are stripped) and stays consistent with the gemm_sms
    # check below, which already raises ValueError.
    if args_dict["n"] % world_size != 0:
        raise ValueError(f"N ({args_dict['n']}) must be divisible by world size ({world_size}).")
    if args_dict["k"] % world_size != 0:
        raise ValueError(f"K ({args_dict['k']}) must be divisible by world size ({world_size}).")

    # Splitting: each rank owns a contiguous K-chunk.
    rows_per_gpu = args_dict["k"] // world_size
    start_row = rank * rows_per_gpu
    end_row = start_row + rows_per_gpu
    local_B = B[start_row:end_row, :]
    local_A = A[:, start_row:end_row]

    # Create output tensors on the symmetric heap so all ranks can reduce
    # into global_C.
    global_C = shmem.zeros((args_dict["m"], args_dict["n"]), device="cuda", dtype=A.dtype)
    local_C = shmem.zeros((args_dict["m"], args_dict["n"]), device="cuda", dtype=A.dtype)

    # Calculate tile information for the output grid.
    total_blocks_M = triton.cdiv(args_dict["m"], args_dict["BLK_M"])
    total_blocks_N = triton.cdiv(args_dict["n"], args_dict["BLK_N"])
    total_tiles = total_blocks_M * total_blocks_N

    if args_dict["gemm_sms"] >= args_dict["total_sms"]:
        raise ValueError(f"Invalid number of stream-K SMs. {args_dict['gemm_sms']} >= {args_dict['total_sms']}")

    # Create synchronization tensors:
    #   tile_completed - per-tile completion flags used for cross-rank progress
    #   locks          - per-SM spin locks for the stream-K kernel
    #   P              - per-SM partial-tile accumulation scratch
    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    locks = shmem.zeros((args_dict["gemm_sms"],), device="cuda", dtype=torch.int32)
    P = shmem.zeros(
        (args_dict["gemm_sms"], args_dict["BLK_M"] * args_dict["BLK_N"]),
        device="cuda",
        dtype=torch.float32,
    )
    bias = None

    # Timestamps for tracing (optional; only populated when trace_tiles is set).
    timestamps = Timestamps(num_tiles=total_tiles)

    def preamble():
        # Reset per-tile completion flags on every rank before a run; the
        # surrounding barriers ensure no rank observes a partially-reset state.
        shmem.barrier()
        iris.memset_tensor(tile_completed, 0)
        shmem.barrier()

    # Prepare for computation. The extra barriers around preamble() are
    # defensive: all ranks must agree the flags are zeroed before launch.
    shmem.barrier()
    preamble()
    shmem.barrier()

    # Run the GEMM + all-reduce.
    shmem.barrier()

    local_C = matmul.apply(
        local_A,
        local_B,
        local_C,
        global_C,
        bias,
        P,
        locks,
        tile_completed,
        rank,
        world_size,
        args_dict["gemm_sms"],
        args_dict["BLK_M"],
        args_dict["BLK_N"],
        args_dict["BLK_K"],
        args_dict["gsize_m"],
        args_dict["two_tiles"],
        args_dict["num_stages"],
        args_dict["num_warps"],
        args_dict["waves_per_eu"],
        args_dict["mfmaInstrSize"],
        args_dict["kpack"],
        shmem.get_heap_bases(),
        cu_count,
        args_dict.get("trace_tiles", False),
        timestamps.mm_begin_timestamp,
        timestamps.mm_end_timestamp,
    )

    # Ensure every rank has finished reducing into global_C before returning.
    shmem.barrier()

    return global_C


def main():
args = parse_args()

Expand All @@ -95,9 +195,6 @@ def main():
print("Unknown datatype.")
exit(1)

assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})."
assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})."

A = shmem.randn(args["m"], args["k"], device="cuda", dtype=datatype)
B = shmem.randn(args["n"], args["k"], device="cuda", dtype=datatype).T
C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)
Expand All @@ -109,41 +206,9 @@ def main():
json_writer = JSONWriter(args["output_file"])
json_writer.add_field("world_size", world_size)

# Splitting
rows_per_gpu = args["k"] // world_size
args["k"] = rows_per_gpu
start_row = rank * rows_per_gpu
end_row = start_row + rows_per_gpu
local_B = B[start_row:end_row, :]
local_A = A[:, start_row:end_row]

for key, value in args.items():
json_writer.add_field(key, value)

global_C = shmem.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype)
local_C = shmem.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype)

total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

if args["gemm_sms"] >= args["total_sms"]:
print(f"Invalid number of stream-K SMs. {args['gemm_sms']} >= {args['total_sms']}")
exit(1)

tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

locks = shmem.zeros((args["gemm_sms"],), device="cuda", dtype=torch.int32)

P = shmem.zeros(
(args["gemm_sms"], args["BLK_M"] * args["BLK_N"]),
device="cuda",
dtype=torch.float32,
)
bias = None

gemm_stream = torch.cuda.Stream()

json_writer.add_field("gemm_sms", args["gemm_sms"])

kernel_timing = {
Expand All @@ -156,55 +221,22 @@ def main():
}

# Timestamps
total_blocks_M = triton.cdiv(args["m"], args["BLK_M"])
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
shmem.barrier()
iris.memset_tensor(tile_completed, 0)
shmem.barrier()

def run_experiment():
nonlocal local_C
nonlocal global_C
nonlocal kernel_timing

shmem.barrier()

if args["trace_tiles"]:
timestamps.reset()
shmem.barrier()

torch.cuda.nvtx.range_push("GEMM + Communication")
with torch.cuda.stream(gemm_stream):
with torch.cuda.stream(torch.cuda.Stream()):
kernel_timing["gemm"]["start_event"].record()
local_C = matmul.apply(
local_A,
local_B,
local_C,
global_C,
bias,
P,
locks,
tile_completed,
rank,
world_size,
args["gemm_sms"],
args["BLK_M"],
args["BLK_N"],
args["BLK_K"],
args["gsize_m"],
args["two_tiles"],
args["num_stages"],
args["num_warps"],
args["waves_per_eu"],
args["mfmaInstrSize"],
args["kpack"],
shmem.get_heap_bases(),
cu_count,
args["trace_tiles"],
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
global_C = gemm_one_shot_all_reduce(A, B, shmem, args)
kernel_timing["gemm"]["end_event"].record()
kernel_timing["gemm"]["experiments"] += 1

Expand All @@ -214,15 +246,15 @@ def run_experiment():
for k in ["gemm"]:
ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"])
kernel_timing[k]["ms"] += ms

return global_C

# Synchronize across all GPUs
shmem.barrier()

# Warmup
run_experiment()
global_C = run_experiment()

shmem.barrier()
preamble()
shmem.barrier()

for k in ["gemm"]:
Expand Down Expand Up @@ -253,6 +285,10 @@ def run_experiment():
if args["benchmark"]:
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)

def preamble():
shmem.barrier()

triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_tflops = perf(triton_ms)
shmem.info(f"tile matmul + all_reduce (grid={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops")
Expand Down
Loading
Loading