@@ -65,8 +65,9 @@
 }
 
 
-def get_specs():
-    gpu_name = torch.cuda.get_device_name(0)
+def get_specs(gpu_name: Optional[str] = None):
+    if gpu_name is None:
+        gpu_name = torch.cuda.get_device_name(0)
     return gpu_name_to_specs[gpu_name]
 
 
@@ -214,10 +215,14 @@ def get_tensor_memory_traffic_ovhd_s(
 
 
 def get_individual_gemm_time_sympy(
-    M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
+    M: sympy.Symbol,
+    K: sympy.Symbol,
+    N: sympy.Symbol,
+    dtype, mx_recipe_name,
+    gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
-    specs = get_specs()
+    specs = get_specs(gpu_name)
     gemm_ops = 2 * M * K * N
     if dtype is torch.bfloat16:
         peak_tops = specs["bf16_peak_tops"]
@@ -265,6 +270,7 @@ def get_gemm_time_sympy(
     dtype,
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
+    gpu_name: Optional[str],
 ):
     # next: add rowwise_with_gw_hp here
     # note: this function is currently not super accurate for small shapes:
@@ -279,13 +285,13 @@ def get_gemm_time_sympy(
     gemm_dtype_grad_weight = torch.bfloat16
 
     gemm_output_time_s = get_individual_gemm_time_sympy(
-        M, K, N, gemm_dtype_input, mx_recipe_name
+        M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
     )
     gemm_grad_input_time_s = get_individual_gemm_time_sympy(
-        M, N, K, gemm_dtype_grad_input, mx_recipe_name
+        M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
    )
     gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
-        K, M, N, gemm_dtype_grad_weight, mx_recipe_name
+        K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
     )
     total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
     return total
@@ -298,8 +304,9 @@ def get_float8_mem_sympy(
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
     enable_fusion_modeling: bool,
+    gpu_name: Optional[str] = None,
 ):
-    specs = get_specs()
+    specs = get_specs(gpu_name)
 
     # there are three gemms in the fwd/bwd of a linear:
     #
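
For reference, a minimal usage sketch of the API after this change. The import path roofline_utils, the spec key "NVIDIA H100", and the concrete shapes are illustrative assumptions, not part of the PR:

import sympy
import torch

from roofline_utils import get_gemm_time_sympy, get_specs  # hypothetical import path

# get_specs() with no argument still falls back to torch.cuda.get_device_name(0),
# so existing callers are unaffected; passing a name lets the roofline be
# evaluated for a GPU that is not present on the current machine.
specs = get_specs("NVIDIA H100")  # assumes this key exists in gpu_name_to_specs

M, K, N = sympy.symbols("M K N")
gemm_time_s = get_gemm_time_sympy(
    M, K, N,
    torch.bfloat16,
    float8_recipe_name=None,
    mx_recipe_name=None,
    gpu_name="NVIDIA H100",
)
# Substitute concrete shapes to get an estimated fwd + bwd gemm time in seconds.
print(gemm_time_s.subs({M: 4096, K: 4096, N: 4096}))

Note that gpu_name defaults to None in get_specs and get_float8_mem_sympy, preserving existing call sites, while get_gemm_time_sympy adds it without a default, so its callers must pass it explicitly.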