Skip to content

Commit afd355e

Browse files
chowarfb authored and
facebook-github-bot committed
Add gpu_name as a parameter in roofline estimate utils (#2657)
Summary: See title. This lets us get estimates without needing to run on the hardware we're getting estimates for. Differential Revision: D79415350
1 parent 5f3ab63 commit afd355e

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

torchao/testing/training/roofline_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@
6565
}
6666

6767

68-
def get_specs():
69-
gpu_name = torch.cuda.get_device_name(0)
68+
def get_specs(gpu_name: Optional[str] = None):
69+
if gpu_name is None:
70+
gpu_name = torch.cuda.get_device_name(0)
7071
return gpu_name_to_specs[gpu_name]
7172

7273

@@ -214,10 +215,14 @@ def get_tensor_memory_traffic_ovhd_s(
214215

215216

216217
def get_individual_gemm_time_sympy(
217-
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
218+
M: sympy.Symbol,
219+
K: sympy.Symbol,
220+
N: sympy.Symbol,
221+
dtype, mx_recipe_name,
222+
gpu_name: Optional[str] = None,
218223
) -> sympy.Symbol:
219224
# compute bound
220-
specs = get_specs()
225+
specs = get_specs(gpu_name)
221226
gemm_ops = 2 * M * K * N
222227
if dtype is torch.bfloat16:
223228
peak_tops = specs["bf16_peak_tops"]
@@ -265,6 +270,7 @@ def get_gemm_time_sympy(
265270
dtype,
266271
float8_recipe_name: Optional[str],
267272
mx_recipe_name: Optional[str],
273+
gpu_name: Optional[str],
268274
):
269275
# next: add rowwise_with_gw_hp here
270276
# note: this function is currently not super accurate for small shapes:
@@ -279,13 +285,13 @@ def get_gemm_time_sympy(
279285
gemm_dtype_grad_weight = torch.bfloat16
280286

281287
gemm_output_time_s = get_individual_gemm_time_sympy(
282-
M, K, N, gemm_dtype_input, mx_recipe_name
288+
M, K, N, gemm_dtype_input, mx_recipe_name, gpu_name
283289
)
284290
gemm_grad_input_time_s = get_individual_gemm_time_sympy(
285-
M, N, K, gemm_dtype_grad_input, mx_recipe_name
291+
M, N, K, gemm_dtype_grad_input, mx_recipe_name, gpu_name
286292
)
287293
gemm_grad_weight_time_s = get_individual_gemm_time_sympy(
288-
K, M, N, gemm_dtype_grad_weight, mx_recipe_name
294+
K, M, N, gemm_dtype_grad_weight, mx_recipe_name, gpu_name
289295
)
290296
total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s
291297
return total
@@ -298,8 +304,9 @@ def get_float8_mem_sympy(
298304
float8_recipe_name: Optional[str],
299305
mx_recipe_name: Optional[str],
300306
enable_fusion_modeling: bool,
307+
gpu_name: Optional[str] = None,
301308
):
302-
specs = get_specs()
309+
specs = get_specs(gpu_name)
303310

304311
# there are three gemms in the fwd/bwd of a linear:
305312
#

0 commit comments

Comments
 (0)