65
65
}
66
66
67
67
68
- def get_specs ():
69
- gpu_name = torch .cuda .get_device_name (0 )
68
def get_specs(gpu_name: Optional[str] = None):
    """Return the hardware spec entry registered for a GPU.

    Args:
        gpu_name: Key into ``gpu_name_to_specs``. When ``None``, the name
            reported by ``torch.cuda.get_device_name(0)`` is used instead.

    Returns:
        The ``gpu_name_to_specs`` entry for the resolved GPU name.

    Raises:
        KeyError: If the resolved name has no entry in ``gpu_name_to_specs``.
    """
    if gpu_name is None:
        # No explicit name given: fall back to the name of CUDA device 0.
        gpu_name = torch.cuda.get_device_name(0)
    try:
        return gpu_name_to_specs[gpu_name]
    except KeyError:
        # Keep the KeyError type for existing callers, but say which names
        # are registered so a typo or unsupported GPU is easy to diagnose.
        known = ", ".join(map(str, gpu_name_to_specs))
        raise KeyError(
            f"no specs registered for GPU {gpu_name!r}; known GPUs: {known}"
        ) from None
71
72
72
73
@@ -213,10 +214,10 @@ def get_tensor_memory_traffic_ovhd_s(
213
214
214
215
215
216
def get_individual_gemm_time_sympy (
216
- M : sympy .Symbol , K : sympy .Symbol , N : sympy .Symbol , dtype , mx_recipe_name
217
+ M : sympy .Symbol , K : sympy .Symbol , N : sympy .Symbol , dtype , mx_recipe_name , gpu_name : Optional [ str ] = None
217
218
) -> sympy .Symbol :
218
219
# compute bound
219
- specs = get_specs ()
220
+ specs = get_specs (gpu_name )
220
221
gemm_ops = 2 * M * K * N
221
222
if dtype is torch .bfloat16 :
222
223
peak_tops = specs ["bf16_peak_tops" ]
@@ -263,6 +264,7 @@ def get_gemm_time_sympy(
263
264
dtype ,
264
265
float8_recipe_name : Optional [str ],
265
266
mx_recipe_name : Optional [str ],
267
+ gpu_name : Optional [str ],
266
268
):
267
269
# next: add rowwise_with_gw_hp here
268
270
# note: this function is currently not super accurate for small shapes:
@@ -277,13 +279,13 @@ def get_gemm_time_sympy(
277
279
gemm_dtype_grad_weight = torch .bfloat16
278
280
279
281
gemm_output_time_s = get_individual_gemm_time_sympy (
280
- M , K , N , gemm_dtype_input , mx_recipe_name
282
+ M , K , N , gemm_dtype_input , mx_recipe_name , gpu_name
281
283
)
282
284
gemm_grad_input_time_s = get_individual_gemm_time_sympy (
283
- M , N , K , gemm_dtype_grad_input , mx_recipe_name
285
+ M , N , K , gemm_dtype_grad_input , mx_recipe_name , gpu_name
284
286
)
285
287
gemm_grad_weight_time_s = get_individual_gemm_time_sympy (
286
- K , M , N , gemm_dtype_grad_weight , mx_recipe_name
288
+ K , M , N , gemm_dtype_grad_weight , mx_recipe_name , gpu_name
287
289
)
288
290
total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
289
291
return total
@@ -296,8 +298,9 @@ def get_float8_mem_sympy(
296
298
float8_recipe_name : Optional [str ],
297
299
mx_recipe_name : Optional [str ],
298
300
enable_fusion_modeling : bool ,
301
+ gpu_name : Optional [str ]= None
299
302
):
300
- specs = get_specs ()
303
+ specs = get_specs (gpu_name )
301
304
302
305
# there are three gemms in the fwd/bwd of a linear:
303
306
#
0 commit comments