Skip to content

Commit 8689634

Browse files
chowarf authored and facebook-github-bot committed
Add gpu_name as a parameter in roofline estimate utils
Summary: See title. This lets us get estimates without needing to run on the hardware we're getting estimates for. Differential Revision: D79415350
1 parent 7c5c0b5 commit 8689634

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

torchao/testing/training/roofline_utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@
6565
}
6666

6767

68-
def get_specs():
69-
gpu_name = torch.cuda.get_device_name(0)
68+
def get_specs(gpu_name: Optional[str]=None):
69+
if gpu_name is None:
70+
gpu_name = torch.cuda.get_device_name(0)
7071
return gpu_name_to_specs[gpu_name]
7172

7273

@@ -213,10 +214,10 @@ def get_tensor_memory_traffic_ovhd_s(
213214

214215

215216
def get_individual_gemm_time_sympy(
216-
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name
217+
M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name, gpu_name: Optional[str]=None
217218
) -> sympy.Symbol:
218219
# compute bound
219-
specs = get_specs()
220+
specs = get_specs(gpu_name)
220221
gemm_ops = 2 * M * K * N
221222
if dtype is torch.bfloat16:
222223
peak_tops = specs["bf16_peak_tops"]
@@ -296,8 +297,9 @@ def get_float8_mem_sympy(
296297
float8_recipe_name: Optional[str],
297298
mx_recipe_name: Optional[str],
298299
enable_fusion_modeling: bool,
300+
gpu_name: Optional[str]=None
299301
):
300-
specs = get_specs()
302+
specs = get_specs(gpu_name)
301303

302304
# there are three gemms in the fwd/bwd of a linear:
303305
#

0 commit comments

Comments
 (0)