# isort: off
import torch
# isort: on
- from cuda import cuda, cudart
+ from cuda import cudart

import tensorrt_llm as tllm
- from tensorrt_llm import Mapping, Tensor
+ from tensorrt_llm import Mapping
+ from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
+ from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
- from tensorrt_llm.functional import (AllReduceParams, AllReduceStrategy,
-                                      allreduce)
- from tensorrt_llm.plugin.plugin import (current_all_reduce_helper,
-                                         init_all_reduce_helper)
- from tensorrt_llm.runtime import Session
+ from tensorrt_llm.bindings.internal.runtime import delay_kernel
+ from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy


def allreduce_benchmark(dtype: str,
-                         test_range: str = "10,10000000,10",
-                         no_header: bool = False):
+                         test_range: str = "1,10000000,10",
+                         no_header: bool = False,
+                         enable_cudagraph: bool = False):
    tllm.logger.set_level('error')
    world_size = tllm.mpi_world_size()
    rank = tllm.mpi_rank()
@@ -49,80 +49,120 @@ def allreduce_benchmark(dtype: str,

    torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
    min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
-     inner_loop = 1000
+     inner_loop = 1200
+     outer_loop = 10

    size = min_size
-     dtype_size = torch.finfo(torch_dtype).bits // 8
+     hidden_size = size
+     bs = 1
    if mapping.rank == 0 and not no_header:
        print(
-             f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}"
+             f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<10}, {'fusion':<20}, {'version':<10}, {'duration (ms)':<10}"
        )
    while size < max_size:
-         input = torch.ones(size, dtype=torch_dtype, device="cuda")
-
-         for strategy in [
-                 AllReduceStrategy.AUTO,
-                 AllReduceStrategy.NCCL,
-                 AllReduceStrategy.ONESHOT,
-                 AllReduceStrategy.TWOSHOT,
-         ]:
-             builder = tllm.Builder()
-             net = builder.create_network()
-             net.plugin_config.set_nccl_plugin(dtype)
-             init_all_reduce_helper()
-             _buffers, workspace = current_all_reduce_helper(
-             ).allocate_workspace(mapping, size * dtype_size)
-
-             with tllm.net_guard(net):
-                 tllm.default_trtnet()
-
-                 x = Tensor(name='x',
-                            shape=input.shape,
-                            dtype=tllm.str_dtype_to_trt(dtype))
-
-                 current_all_reduce_helper().set_workspace_tensor(mapping)
-
-                 current = x
-                 for _ in range(inner_loop):
-                     current = allreduce(
-                         current,
-                         mapping.tp_group,
-                         all_reduce_params=AllReduceParams(strategy=strategy))
-                 current.mark_output('output', dtype)
-             feed_dict = {'x': input, 'all_reduce_workspace': workspace}
-             builder_config = builder.create_builder_config(precision=dtype)
-             engine = builder.build_engine(net, builder_config)
-             assert engine is not None, "Failed to build engine"
-             session = Session.from_serialized_engine(engine)
-
-             _, start = cuda.cuEventCreate(0)
-             _, stop = cuda.cuEventCreate(0)
-             runtimes = []
-
-             tllm.mpi_barrier()
-             output = torch.empty(input.shape, dtype=torch_dtype, device='cuda')
-             stream = torch.cuda.current_stream()
-             for _ in range(10):
-                 cuda.cuEventRecord(start, stream.cuda_stream)
-                 session.run(inputs=feed_dict,
-                             outputs={"output": output},
-                             stream=stream.cuda_stream)
-                 cuda.cuEventRecord(stop, stream.cuda_stream)
-                 torch.cuda.synchronize()
-                 _, ms = cuda.cuEventElapsedTime(start, stop)
-                 runtimes.append(ms)
-
-             median_ms = sorted(runtimes)[len(runtimes) // 2]
-
-             allreduce_ref = (input * world_size)**inner_loop
-             assert torch.allclose(output, allreduce_ref)
-
-             if mapping.rank == 0:
-                 print(
-                     f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}"
-                 )
+         input = torch.ones((bs, hidden_size), dtype=torch_dtype, device="cuda")
+
+         for version in ["v1"]:
+             for fusion in [
+                     AllReduceFusionOp.RESIDUAL_RMS_NORM, AllReduceFusionOp.NONE
+             ]:
+                 for strategy in [
+                         AllReduceStrategy.NCCL,
+                         AllReduceStrategy.ONESHOT,
+                         AllReduceStrategy.TWOSHOT,
+                 ]:
+                     if size >= 25600000 and fusion != AllReduceFusionOp.NONE:
+                         continue
+                     allreduce = AllReduce(mapping=mapping, strategy=strategy)
+                     if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
+                         norm_weight = torch.randn((hidden_size, ),
+                                                   dtype=torch_dtype,
+                                                   device="cuda")
+                         norm = RMSNorm(hidden_size=hidden_size,
+                                        dtype=torch_dtype,
+                                        eps=1e-5).cuda()
+                         norm.weight.data.copy_(norm_weight)
+                         if version == "v1":
+                             params = {
+                                 "all_reduce_params":
+                                 AllReduceParams(fusion_op=fusion,
+                                                 residual=input,
+                                                 norm_weight=norm.weight,
+                                                 eps=norm.variance_epsilon)
+                             }
+                         else:
+                             params = {
+                                 "reduce_fusion_inputs": [input, norm.weight],
+                                 "eps": norm.variance_epsilon,
+                                 "fusion_op": fusion
+                             }
+                     else:
+                         if version == "v1":
+                             params = {
+                                 "all_reduce_params":
+                                 AllReduceParams(fusion_op=fusion)
+                             }
+                         else:
+                             continue
+
+                     def func(input):
+                         for _ in range(inner_loop):
+                             input = allreduce(input, **params)
+                             if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM:
+                                 input = input[0]
+                         return input
+
+                     start = [
+                         torch.cuda.Event(enable_timing=True)
+                         for _ in range(outer_loop)
+                     ]
+                     stop = [
+                         torch.cuda.Event(enable_timing=True)
+                         for _ in range(outer_loop)
+                     ]
+                     graph = torch.cuda.CUDAGraph()
+
+                     stream = torch.cuda.Stream()
+                     with torch.cuda.stream(stream):
+                         if enable_cudagraph:
+                             for _ in range(2):
+                                 func(input)
+                             with torch.cuda.graph(graph, stream=stream):
+                                 output = func(input)
+                         tllm.mpi_barrier()
+                         delay_kernel(2000000, stream)
+                         torch.cuda.profiler.start()
+                         for i in range(outer_loop):
+                             start[i].record(stream)
+                             if enable_cudagraph:
+                                 graph.replay()
+                             else:
+                                 output = func(input)
+                             stop[i].record(stream)
+
+                     torch.cuda.synchronize()
+                     torch.cuda.profiler.stop()
+                     runtimes = [
+                         start[i].elapsed_time(stop[i])
+                         for i in range(outer_loop)
+                     ]
+                     median_ms = sorted(runtimes)[len(runtimes) // 2]
+
+                     if fusion == AllReduceFusionOp.NONE:
+                         allreduce_ref = (input * world_size)**inner_loop
+                         torch.testing.assert_close(output, allreduce_ref)
+
+                     if mapping.rank == 0:
+                         print(
+                             f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<10}, {fusion.name:<20}, {version:<10}, {median_ms:<10.2f}"
+                         )

        size *= ratio
+         if hidden_size * ratio > 4096:
+             bs *= ratio
+         else:
+             hidden_size *= ratio
+         assert size == bs * hidden_size


if __name__ == "__main__":
@@ -134,6 +174,8 @@ def allreduce_benchmark(dtype: str,
                        default="256,256000000,10",  # 256 to 256M
                        help="min_size,max_size,multiplicative_ratio")
    parser.add_argument("--no-header", action="store_true")
+     parser.add_argument("--enable-cudagraph", action="store_true")
    args = parser.parse_args()

-     allreduce_benchmark(args.dtype, args.range, args.no_header)
+     allreduce_benchmark(args.dtype, args.range, args.no_header,
+                         args.enable_cudagraph)
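
Taken together, the additions above reduce to one call pattern: build an AllReduce module for the tensor-parallel Mapping, then invoke it with AllReduceParams, optionally requesting RESIDUAL_RMS_NORM fusion, in which case the call returns several tensors and the benchmark forwards element 0. Below is a minimal standalone sketch of that pattern. It is illustrative only and not part of this diff; it assumes one MPI process per GPU (e.g. launched via mpirun), the same imports the change introduces, and an assumed plain tensor-parallel Mapping construction.

# Minimal usage sketch (illustrative, not part of this change).
import torch

import tensorrt_llm as tllm
from tensorrt_llm import Mapping
from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._utils import local_mpi_rank
from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy

world_size = tllm.mpi_world_size()
rank = tllm.mpi_rank()
torch.cuda.set_device(local_mpi_rank())  # one GPU per MPI rank assumed

# Pure tensor-parallel mapping over all ranks (assumed construction).
mapping = Mapping(world_size=world_size, rank=rank, tp_size=world_size)

hidden_size = 4096
x = torch.ones((1, hidden_size), dtype=torch.float16, device="cuda")
norm = RMSNorm(hidden_size=hidden_size, dtype=torch.float16, eps=1e-5).cuda()

allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.ONESHOT)

# Unfused path: the call returns a single reduced tensor.
reduced = allreduce(x,
                    all_reduce_params=AllReduceParams(
                        fusion_op=AllReduceFusionOp.NONE))

# Fused allreduce + residual add + RMSNorm: the call returns several tensors;
# the benchmark loop above forwards element 0 to the next iteration.
fused = allreduce(x,
                  all_reduce_params=AllReduceParams(
                      fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
                      residual=x,
                      norm_weight=norm.weight,
                      eps=norm.variance_epsilon))
normed = fused[0]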