 }


-def get_specs():
-    gpu_name = torch.cuda.get_device_name(0)
+def get_specs(gpu_name: Optional[str] = None):
+    if gpu_name is None:
+        gpu_name = torch.cuda.get_device_name(0)
     return gpu_name_to_specs[gpu_name]


@@ -214,10 +215,14 @@ def get_tensor_memory_traffic_ovhd_s(


 def get_individual_gemm_time_sympy(
-    M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
+    M: sympy.Symbol,
+    K: sympy.Symbol,
+    N: sympy.Symbol,
+    dtype, mx_recipe_name,
+    gpu_name: Optional[str] = None,
 ) -> sympy.Symbol:
     # compute bound
-    specs = get_specs()
+    specs = get_specs(gpu_name)
     gemm_ops = 2 * M * K * N
     if dtype is torch.bfloat16:
         peak_tops = specs["bf16_peak_tops"]
@@ -265,6 +270,7 @@ def get_gemm_time_sympy(
     dtype,
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
+    gpu_name: Optional[str],
 ):
     # next: add rowwise_with_gw_hp here
     # note: this function is currently not super accurate for small shapes:
@@ -279,13 +285,13 @@ def get_gemm_time_sympy(
         gemm_dtype_grad_weight = torch.bfloat16

     gemm_output_time_s = get_individual_gemm_time_sympy(
-        M, K, N, gemm_dtype_input, mx_recipe_name
+        M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
     )
     gemm_grad_input_time_s = get_individual_gemm_time_sympy(
-        M, N, K, gemm_dtype_grad_input, mx_recipe_name
+        M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
     )
     gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
-        K, M, N, gemm_dtype_grad_weight, mx_recipe_name
+        K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
     )
     total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
     return total
@@ -298,8 +304,9 @@ def get_float8_mem_sympy(
     float8_recipe_name: Optional[str],
     mx_recipe_name: Optional[str],
     enable_fusion_modeling: bool,
+    gpu_name: Optional[str] = None
 ):
-    specs = get_specs()
+    specs = get_specs(gpu_name)

     # there are three gemms in the fwd/bwd of a linear:
     #
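
For context, a minimal sketch of how the newly threaded-through gpu_name argument might be exercised to build a roofline estimate for a device other than the one locally attached. The import path, the leading M/K/N parameters, and the "NVIDIA H100" spec-table key are assumptions for illustration, not confirmed by this diff:

import sympy
import torch

# Module path is an assumption; adjust to wherever get_gemm_time_sympy lives.
from torchao.testing.float8.roofline_utils import get_gemm_time_sympy

M, K, N = sympy.symbols("M K N")

# Before this change, get_specs() always called torch.cuda.get_device_name(0),
# so the symbolic gemm-time model could only target the GPU physically
# present. Passing gpu_name explicitly lets it target a different device.
gemm_time_s = get_gemm_time_sympy(
    M, K, N,              # assumed to be the leading parameters, per the body above
    torch.bfloat16,
    None,                 # float8_recipe_name
    None,                 # mx_recipe_name
    "NVIDIA H100",        # assumed key in gpu_name_to_specs
)
print(gemm_time_s.subs({M: 4096, K: 4096, N: 4096}))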