Commit c6de9b4

simplify Float8Linear (#2594)
Update [ghstack-poisoned]
1 parent 12ff479 commit c6de9b4

2 files changed (+6 additions, -67 deletions)

torchao/float8/README.md

Lines changed: 5 additions & 5 deletions
@@ -211,12 +211,12 @@ To reproduce these benchmarks, you can follow these steps:
 1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
 including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
 2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
-3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above:
-  - bf16 + compile: `TORCHTITAN_ROOT=<path> ./float8_training_benchmark.sh`
-  - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
-  - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`
+3. From the `torchao/benchmarks/float8/training/` directory, you can run the following commands to reproduce the benchmarks above:
+  - bf16 + compile: `TORCHTITAN_ROOT=<path> ./torchtitan_benchmark.sh`
+  - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./torchtitan_benchmark.sh`
+  - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./torchtitan_benchmark.sh`
 
-See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.
+See the float8 training benchmarking [guide](.torchao/benchmarks/float8/training/README.md) for more details.
 
 # E2E training + inference flow
 
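On the Python side, the two recipe names used in these commands select a `Float8LinearConfig`. A minimal sketch based on the torchao float8 README (assuming the `Float8LinearConfig.from_recipe_name` API; variable names here are illustrative):

```python
from torchao.float8 import Float8LinearConfig

# "tensorwise" and "rowwise" mirror the FLOAT8_RECIPE_WITH_BEST_SETTINGS
# values passed to torchtitan_benchmark.sh above
tensorwise_config = Float8LinearConfig()  # tensorwise scaling is the default
rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")
```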

torchao/float8/float8_linear.py

Lines changed: 1 addition & 62 deletions
@@ -21,41 +21,10 @@
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-def _get_weight_scale(
-    weight: torch.Tensor,
-    scaling_type_weight: ScalingType,
-    config: Float8LinearConfig,
-) -> Optional[torch.Tensor]:
-    if tensor_already_casted_to_fp8(weight):
-        return None
-    assert scaling_type_weight is ScalingType.DYNAMIC
-    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
-
-
-def _cast_weight_to_float8_t(
-    weight: torch.Tensor,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-    weight_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    if tensor_already_casted_to_fp8(weight):
-        return weight.t()
-    weight_fp8 = hp_tensor_and_scale_to_float8(
-        weight,
-        weight_scale,
-        config.cast_config_weight.target_dtype,
-        linear_mm_config,
-        gemm_input_role=GemmInputRole.WEIGHT,
-    )
-    return weight_fp8.t()
-
-
 @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
@@ -307,39 +276,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        has_any_axiswise_scaling = any(
-            cc.scaling_granularity is ScalingGranularity.AXISWISE
-            for cc in [
-                self.config.cast_config_input,
-                self.config.cast_config_weight,
-                self.config.cast_config_grad_output,
-                self.config.cast_config_input_for_grad_weight,
-                self.config.cast_config_weight_for_grad_input,
-                self.config.cast_config_grad_output_for_grad_weight,
-            ]
-        )
-
-        weight_maybe_fp8_t = self.weight.t()
-
-        # TODO(future PR): check for axiswise scaling for input, weight,
-        # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
-            # TODO(future PR): now that `force_recompute_fp8_weight_in_bwd` is
-            # deprecated, we can simplify the below code and unify the per-tensor
-            # and per-axis paths further.
-            weight_scale = _get_weight_scale(
-                self.weight, self.scaling_type_weight, self.config
-            )
-            weight_maybe_fp8_t = _cast_weight_to_float8_t(
-                self.weight,
-                self.config,
-                self.linear_mm_config,
-                weight_scale,
-            )
-
         output = matmul_with_hp_or_float8_args.apply(
             input,
-            weight_maybe_fp8_t,
+            self.weight.t(),
             self.linear_mm_config,
             self.config,
         )
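With the eager per-tensor path gone, `forward` always passes the high-precision transposed weight to `matmul_with_hp_or_float8_args`, which now owns all input/weight/grad_output casting for both per-tensor and axiswise recipes. An end-to-end usage sketch based on the torchao float8 README (assumes a float8-capable GPU such as an H100; dimensions chosen as multiples of 16, as float8 gemms require):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

m = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128)).bfloat16().cuda()

# Swaps each nn.Linear for Float8Linear; after this commit, forward defers
# all float8 casting to matmul_with_hp_or_float8_args.
convert_to_float8_training(m)
m = torch.compile(m)

x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
```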
