Commit 4b119ed

update float8 readme with more recent performance numbers (#2580)
1. run roofline script for tensorwise and rowwise recipes on recent torch and torchao
2. add section for rowwise_with_gw_hp
1 parent b6ef500 commit 4b119ed

1 file changed: +16 -14 lines changed

torchao/float8/README.md

Lines changed: 16 additions & 14 deletions
@@ -97,30 +97,32 @@ on using `torchao.float8` in a distributed setting.

 A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the tables below for a microbenchmark based speedup estimate on NVIDIA H100:

-### Tensorwise scaling
+### tensorwise scaling

-<img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">
+<img width="753" height="773" alt="Image" src="https://github.com/user-attachments/assets/e46c671a-ed35-41b4-b17c-50caf1629ecb" />

-Example 1 (small shapes):
-* forward input tensor size 1024x2048, linear weight size 2048x1024; M, K, N = 1024, 2048, 1024
-* benchmark speedup is 0.80
-* recommendation: leave this linear in bfloat16, the shapes are too small to benefit from float8 compute
+```lang=shell
+# reproduction: run the script below
+python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
+```

-Example 2 (large shapes):
-* forward input tensor size 4096x8192, linear weight size 8192x16384; M, K, N = 4096, 8192, 16384
-* benchmark speedup is 1.39
-* recommendation: enable float8 for this linear to get a speedup
+### rowwise scaling

-To reproduce the raw data for table above, you can run the following script
+<img width="755" height="778" alt="Image" src="https://github.com/user-attachments/assets/7d70ba36-f480-459f-b5c0-797895332631" />

 ```lang=shell
-python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
+# reproduction: run the script below
+python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise
 ```

-### Rowwise scaling
+### rowwise_with_gw_hp scaling

-<img width="805" alt="float8_rowwise_speedup" src="../../docs/static/fp8-rowwise-perf.png" />
+<img width="750" height="797" alt="Image" src="https://github.com/user-attachments/assets/e4479abc-1aca-436d-a142-60e5e804ff10" />

+```lang=shell
+# reproduction: run the script below
+python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise_with_gw_hp
+```

 ## Derivation

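Note for readers of this diff: the README text asks you to look up the speedup tables by the M, K, N of the linear's forward pass, and this commit drops the two worked examples that showed how those values come from tensor shapes. The sketch below reconstructs that mapping from the removed examples. `linear_mkn` is a hypothetical helper, not part of torchao, and flattening every leading input dimension into M is an assumption on my part rather than something the README states.

```python
from math import prod


def linear_mkn(input_shape, in_features, out_features):
    """Return (M, K, N) for the forward matmul of a linear layer.

    Hypothetical helper for illustration only (not a torchao API).
    Assumption: all leading input dimensions are flattened into M.
    """
    *leading, k = input_shape
    if k != in_features:
        raise ValueError(
            f"last input dim {k} does not match in_features {in_features}"
        )
    m = prod(leading)  # product of every dimension except the last
    return m, in_features, out_features


# Shapes taken from the two examples removed by this commit.
print(linear_mkn((1024, 2048), 2048, 1024))   # small shapes
print(linear_mkn((4096, 8192), 8192, 16384))  # large shapes
```

The resulting triple is what you would look up in the tensorwise, rowwise, or rowwise_with_gw_hp table for the recipe you plan to use.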