Skip to content

Commit 73ff48f

Browse files
chowarf authored and facebook-github-bot committed
Add gpu_name as a parameter in roofline estimate utils (#2657)
Summary: Pull Request resolved: #2657. See title — this lets us get estimates without needing to run on the hardware we're getting estimates for. Differential Revision: D79415350
1 parent 7c5c0b5 commit 73ff48f

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

torchao/testing/training/roofline_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@
6565
}
6666

6767

68-
def get_specs():
69-
gpu_name = torch.cuda.get_device_name(0)
68+
def get_specs(gpu_name: Optional[str]=None):
69+
if gpu_name is None:
70+
gpu_name = torch.cuda.get_device_name(0)
7071
return gpu_name_to_specs[gpu_name]
7172

7273

@@ -213,10 +214,10 @@ def get_tensor_memory_traffic_ovhd_s(
213214

214215

215216
def get_individual_gemm_time_sympy(
216-
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
217+
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name, gpu_name: Optional[str]=None
217218
) -> sympy.Symbol:
218219
# compute bound
219-
specs = get_specs()
220+
specs = get_specs(gpu_name)
220221
gemm_ops = 2 * M * K * N
221222
if dtype is torch.bfloat16:
222223
peak_tops = specs["bf16_peak_tops"]
@@ -263,6 +264,7 @@ def get_gemm_time_sympy(
263264
dtype,
264265
float8_recipe_name: Optional[str],
265266
mx_recipe_name: Optional[str],
267+
gpu_name: Optional[str],
266268
):
267269
# next: add rowwise_with_gw_hp here
268270
# note: this function is currently not super accurate for small shapes:
@@ -277,13 +279,13 @@ def get_gemm_time_sympy(
277279
gemm_dtype_grad_weight = torch.bfloat16
278280

279281
gemm_output_time_s = get_individual_gemm_time_sympy(
280-
M, K, N, gemm_dtype_input, mx_recipe_name
282+
M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
281283
)
282284
gemm_grad_input_time_s = get_individual_gemm_time_sympy(
283-
M, N, K, gemm_dtype_grad_input, mx_recipe_name
285+
M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
284286
)
285287
gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
286-
K, M, N, gemm_dtype_grad_weight, mx_recipe_name
288+
K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
287289
)
288290
total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
289291
return total
@@ -296,8 +298,9 @@ def get_float8_mem_sympy(
296298
float8_recipe_name: Optional[str],
297299
mx_recipe_name: Optional[str],
298300
enable_fusion_modeling: bool,
301+
gpu_name: Optional[str]=None
299302
):
300-
specs = get_specs()
303+
specs = get_specs(gpu_name)
301304

302305
# there are three gemms in the fwd/bwd of a linear:
303306
#

0 commit comments

Comments (0)