
Commit 43bd861

Update allreduce benchmark for torch (#6271)

Signed-off-by: Iman Tabrizian <[email protected]>

1 parent: 83ee91e

1 file changed: +118 -76 lines


tests/microbenchmarks/all_reduce.py

@@ -18,21 +18,21 @@
 # isort: off
 import torch
 # isort: on
-from cuda import cuda, cudart
+from cuda import cudart
 
 import tensorrt_llm as tllm
-from tensorrt_llm import Mapping, Tensor
+from tensorrt_llm import Mapping
+from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
+from tensorrt_llm._torch.modules.rms_norm import RMSNorm
 from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
-from tensorrt_llm.functional import (AllReduceParams, AllReduceStrategy,
-                                     allreduce)
-from tensorrt_llm.plugin.plugin import (current_all_reduce_helper,
-                                        init_all_reduce_helper)
-from tensorrt_llm.runtime import Session
+from tensorrt_llm.bindings.internal.runtime import delay_kernel
+from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
 
 
 def allreduce_benchmark(dtype: str,
-                        test_range: str = "10,10000000,10",
-                        no_header: bool = False):
+                        test_range: str = "1,10000000,10",
+                        no_header: bool = False,
+                        enable_cudagraph: bool = False):
     tllm.logger.set_level('error')
     world_size = tllm.mpi_world_size()
     rank = tllm.mpi_rank()
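In short, the commit drops the TensorRT network/engine build around the `allreduce` functional and drives the torch-native `AllReduce` module directly. A minimal sketch of that call pattern, using only names from the imports above; the `Mapping` construction and tensor shape are illustrative, and the script must run under MPI with one rank per GPU:

```python
import torch

import tensorrt_llm as tllm
from tensorrt_llm import Mapping
from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy

# Illustrative mapping: pure tensor parallelism across all MPI ranks.
world_size = tllm.mpi_world_size()
mapping = Mapping(world_size=world_size,
                  rank=tllm.mpi_rank(),
                  tp_size=world_size)

allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.ONESHOT)
x = torch.ones((1, 4096), dtype=torch.bfloat16, device="cuda")
# Plain (unfused) all-reduce; the fused variant in the diff additionally
# passes residual, norm_weight, and eps through AllReduceParams.
y = allreduce(x,
              all_reduce_params=AllReduceParams(fusion_op=AllReduceFusionOp.NONE))
```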
@@ -49,80 +49,120 @@ def allreduce_benchmark(dtype: str,
 
     torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
     min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
-    inner_loop = 1000
+    inner_loop = 1200
+    outer_loop = 10
 
     size = min_size
-    dtype_size = torch.finfo(torch_dtype).bits // 8
+    hidden_size = size
+    bs = 1
     if mapping.rank == 0 and not no_header:
         print(
-            f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}"
+            f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<10}, {'fusion':<20}, {'version':<10}, {'duration (ms)':<10}"
         )
     while size < max_size:
-        input = torch.ones(size, dtype=torch_dtype, device="cuda")
-
-        for strategy in [
-                AllReduceStrategy.AUTO,
-                AllReduceStrategy.NCCL,
-                AllReduceStrategy.ONESHOT,
-                AllReduceStrategy.TWOSHOT,
-        ]:
-            builder = tllm.Builder()
-            net = builder.create_network()
-            net.plugin_config.set_nccl_plugin(dtype)
-            init_all_reduce_helper()
-            _buffers, workspace = current_all_reduce_helper(
-            ).allocate_workspace(mapping, size * dtype_size)
-
-            with tllm.net_guard(net):
-                tllm.default_trtnet()
-
-                x = Tensor(name='x',
-                           shape=input.shape,
-                           dtype=tllm.str_dtype_to_trt(dtype))
-
-                current_all_reduce_helper().set_workspace_tensor(mapping)
-
-                current = x
-                for _ in range(inner_loop):
-                    current = allreduce(
-                        current,
-                        mapping.tp_group,
-                        all_reduce_params=AllReduceParams(strategy=strategy))
-                current.mark_output('output', dtype)
-            feed_dict = {'x': input, 'all_reduce_workspace': workspace}
-            builder_config = builder.create_builder_config(precision=dtype)
-            engine = builder.build_engine(net, builder_config)
-            assert engine is not None, "Failed to build engine"
-            session = Session.from_serialized_engine(engine)
-
-            _, start = cuda.cuEventCreate(0)
-            _, stop = cuda.cuEventCreate(0)
-            runtimes = []
-
-            tllm.mpi_barrier()
-            output = torch.empty(input.shape, dtype=torch_dtype, device='cuda')
-            stream = torch.cuda.current_stream()
-            for _ in range(10):
-                cuda.cuEventRecord(start, stream.cuda_stream)
-                session.run(inputs=feed_dict,
-                            outputs={"output": output},
-                            stream=stream.cuda_stream)
-                cuda.cuEventRecord(stop, stream.cuda_stream)
-                torch.cuda.synchronize()
-                _, ms = cuda.cuEventElapsedTime(start, stop)
-                runtimes.append(ms)
-
-            median_ms = sorted(runtimes)[len(runtimes) // 2]
-
-            allreduce_ref = (input * world_size)**inner_loop
-            assert torch.allclose(output, allreduce_ref)
-
-            if mapping.rank == 0:
-                print(
-                    f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}"
-                )
+        input = torch.ones((bs, hidden_size), dtype=torch_dtype, device="cuda")
+
+        for version in ["v1"]:
+            for fusion in [
+                    AllReduceFusionOp.RESIDUAL_RMS_NORM, AllReduceFusionOp.NONE
+            ]:
+                for strategy in [
+                        AllReduceStrategy.NCCL,
+                        AllReduceStrategy.ONESHOT,
+                        AllReduceStrategy.TWOSHOT,
+                ]:
+                    if size >= 25600000 and fusion != AllReduceFusionOp.NONE:
+                        continue
+                    allreduce = AllReduce(mapping=mapping, strategy=strategy)
+                    if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
+                        norm_weight = torch.randn((hidden_size, ),
+                                                  dtype=torch_dtype,
+                                                  device="cuda")
+                        norm = RMSNorm(hidden_size=hidden_size,
+                                       dtype=torch_dtype,
+                                       eps=1e-5).cuda()
+                        norm.weight.data.copy_(norm_weight)
+                        if version == "v1":
+                            params = {
+                                "all_reduce_params":
+                                AllReduceParams(fusion_op=fusion,
+                                                residual=input,
+                                                norm_weight=norm.weight,
+                                                eps=norm.variance_epsilon)
+                            }
+                        else:
+                            params = {
+                                "reduce_fusion_inputs": [input, norm.weight],
+                                "eps": norm.variance_epsilon,
+                                "fusion_op": fusion
+                            }
+                    else:
+                        if version == "v1":
+                            params = {
+                                "all_reduce_params":
+                                AllReduceParams(fusion_op=fusion)
+                            }
+                        else:
+                            continue
+
+                    def func(input):
+                        for _ in range(inner_loop):
+                            input = allreduce(input, **params)
+                            if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
+                                input = input[0]
+                        return input
+
+                    start = [
+                        torch.cuda.Event(enable_timing=True)
+                        for _ in range(outer_loop)
+                    ]
+                    stop = [
+                        torch.cuda.Event(enable_timing=True)
+                        for _ in range(outer_loop)
+                    ]
+                    graph = torch.cuda.CUDAGraph()
+
+                    stream = torch.cuda.Stream()
+                    with torch.cuda.stream(stream):
+                        if enable_cudagraph:
+                            for _ in range(2):
+                                func(input)
+                            with torch.cuda.graph(graph, stream=stream):
+                                output = func(input)
+                        tllm.mpi_barrier()
+                        delay_kernel(2000000, stream)
+                        torch.cuda.profiler.start()
+                        for i in range(outer_loop):
+                            start[i].record(stream)
+                            if enable_cudagraph:
+                                graph.replay()
+                            else:
+                                output = func(input)
+                            stop[i].record(stream)
+
+                    torch.cuda.synchronize()
+                    torch.cuda.profiler.stop()
+                    runtimes = [
+                        start[i].elapsed_time(stop[i])
+                        for i in range(outer_loop)
+                    ]
+                    median_ms = sorted(runtimes)[len(runtimes) // 2]
+
+                    if fusion == AllReduceFusionOp.NONE:
+                        allreduce_ref = (input * world_size)**inner_loop
+                        torch.testing.assert_close(output, allreduce_ref)
+
+                    if mapping.rank == 0:
+                        print(
+                            f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<10}, {fusion.name:<20}, {version:<10}, {median_ms:<10.2f}"
+                        )
 
         size *= ratio
+        if hidden_size * ratio > 4096:
+            bs *= ratio
+        else:
+            hidden_size *= ratio
+        assert size == bs * hidden_size
 
 
 if __name__ == "__main__":
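The measurement skeleton in the hunk above — one CUDA event pair per outer iteration, optional graph capture with replay, and the median of `outer_loop` runs — is reusable on its own (the `tllm.mpi_barrier()` plus `delay_kernel` pair lines the ranks up before timing starts). A standalone sketch of the same pattern, with a trivial stand-in workload replacing the all-reduce and the MPI/delay handshake omitted:

```python
import torch


def benchmark_median_ms(fn, outer_loop=10, enable_cudagraph=False):
    """Time fn on a dedicated stream; return the median latency in ms."""
    start = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    stop = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)]
    graph = torch.cuda.CUDAGraph()

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        if enable_cudagraph:
            for _ in range(2):  # warm up before capture
                fn()
            with torch.cuda.graph(graph, stream=stream):
                fn()
        for i in range(outer_loop):
            start[i].record(stream)
            if enable_cudagraph:
                graph.replay()
            else:
                fn()
            stop[i].record(stream)
    torch.cuda.synchronize()
    runtimes = [start[i].elapsed_time(stop[i]) for i in range(outer_loop)]
    return sorted(runtimes)[len(runtimes) // 2]


# Trivial stand-in workload; the benchmark above times the (fused) all-reduce.
x = torch.ones((1, 4096), device="cuda")
print(benchmark_median_ms(lambda: torch.relu(x)))
```

Recording one event pair per replay and reporting the median keeps the figure robust to stragglers and one-off launch overhead.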
@@ -134,6 +174,8 @@ def allreduce_benchmark(dtype: str,
                         default="256,256000000,10",  # 256 to 256M
                         help="min_size,max_size,multiplicative_ratio")
     parser.add_argument("--no-header", action="store_true")
+    parser.add_argument("--enable-cudagraph", action="store_true")
     args = parser.parse_args()
 
-    allreduce_benchmark(args.dtype, args.range, args.no_header)
+    allreduce_benchmark(args.dtype, args.range, args.no_header,
+                        args.enable_cudagraph)
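For reference, the argparse wiring maps onto the updated signature as below. This is a sketch only: the `--dtype`/`--range` flag definitions and the dtype default fall outside this hunk, so those values are assumptions, and under MPI each rank executes the same call:

```python
# Direct call mirroring the CLI defaults shown above (dtype value assumed).
allreduce_benchmark(dtype="float16",
                    test_range="256,256000000,10",
                    no_header=False,
                    enable_cudagraph=True)
```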
