@@ -21,41 +21,10 @@
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
 )
-from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-def _get_weight_scale(
-    weight: torch.Tensor,
-    scaling_type_weight: ScalingType,
-    config: Float8LinearConfig,
-) -> Optional[torch.Tensor]:
-    if tensor_already_casted_to_fp8(weight):
-        return None
-    assert scaling_type_weight is ScalingType.DYNAMIC
-    return tensor_to_scale(weight, config.cast_config_weight.target_dtype)
-
-
-def _cast_weight_to_float8_t(
-    weight: torch.Tensor,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-    weight_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    if tensor_already_casted_to_fp8(weight):
-        return weight.t()
-    weight_fp8 = hp_tensor_and_scale_to_float8(
-        weight,
-        weight_scale,
-        config.cast_config_weight.target_dtype,
-        linear_mm_config,
-        gemm_input_role=GemmInputRole.WEIGHT,
-    )
-    return weight_fp8.t()
-
-
 @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
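The hunk above deletes the per-tensor weight-cast helpers `_get_weight_scale` and `_cast_weight_to_float8_t`, together with the `hp_tensor_and_scale_to_float8` and `tensor_to_scale` imports they required. For context, here is a minimal, self-contained sketch of the amax-based dynamic per-tensor scaling those helpers implemented, assuming the standard recipe (scale = max representable float8 value / amax of the tensor). The `_demo_*` names are illustrative stand-ins, not torchao internals, and the sketch skips the `Float8Tensor` subclass bookkeeping the real helpers produce.

```python
import torch

def _demo_tensor_to_scale(hp_tensor: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Standard dynamic recipe: scale = float8 dtype max / amax, computed in
    # high precision so a tiny amax does not blow up the scale.
    amax = hp_tensor.abs().max().to(torch.float64)
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
    return scale.to(torch.float32)

def _demo_cast_weight_t(weight: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn):
    # Mirror the "scale -> saturate -> cast -> .t()" shape of the removed
    # helpers, minus the Float8Tensor wrapper and the LinearMMConfig plumbing.
    scale = _demo_tensor_to_scale(weight, float8_dtype)
    fp8_max = torch.finfo(float8_dtype).max
    weight_fp8 = torch.clamp(weight.to(torch.float32) * scale, -fp8_max, fp8_max).to(float8_dtype)
    return weight_fp8.t(), scale

w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fp8_t, w_scale = _demo_cast_weight_t(w)
print(w_fp8_t.shape, w_fp8_t.dtype, w_scale.item())
```

The real helpers also short-circuited when the weight was already a float8 tensor (the `tensor_already_casted_to_fp8` check visible above), presumably to interoperate with wrappers like `WeightWithDynamicFloat8CastTensor` that hand the module an already-cast weight.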
@@ -307,39 +276,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        has_any_axiswise_scaling = any(
-            cc.scaling_granularity is ScalingGranularity.AXISWISE
-            for cc in [
-                self.config.cast_config_input,
-                self.config.cast_config_weight,
-                self.config.cast_config_grad_output,
-                self.config.cast_config_input_for_grad_weight,
-                self.config.cast_config_weight_for_grad_input,
-                self.config.cast_config_grad_output_for_grad_weight,
-            ]
-        )
-
-        weight_maybe_fp8_t = self.weight.t()
-
-        # TODO(future PR): check for axiswise scaling for input, weight,
-        # grad_output separately instead of together
-        if not has_any_axiswise_scaling:
-            # TODO(future PR): now that `force_recompute_fp8_weight_in_bwd` is
-            # deprecated, we can simplify the below code and unify the per-tensor
-            # and per-axis paths further.
-            weight_scale = _get_weight_scale(
-                self.weight, self.scaling_type_weight, self.config
-            )
-            weight_maybe_fp8_t = _cast_weight_to_float8_t(
-                self.weight,
-                self.config,
-                self.linear_mm_config,
-                weight_scale,
-            )
-
         output = matmul_with_hp_or_float8_args.apply(
             input,
-            weight_maybe_fp8_t,
+            self.weight.t(),
             self.linear_mm_config,
             self.config,
         )
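The second hunk drops the eager tensorwise special case in `Float8Linear.forward`: there is no longer a pre-computed `weight_maybe_fp8_t`, and the transposed weight (high precision, or already cast to float8) is handed straight to `matmul_with_hp_or_float8_args`, which the name and the removed branch suggest already handles casting for every recipe, so the tensorwise and axiswise paths no longer diverge in `forward`. End-user code should be unaffected; a rough usage sketch follows. `convert_to_float8_training` and `Float8LinearConfig` are real `torchao.float8` entry points, but treat the exact keyword names here as assumptions, and note that float8 matmuls require a recent GPU (compute capability 8.9+).

```python
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Two plain linears; dimensions kept as multiples of 16, which the
# underlying scaled-matmul kernels require.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, bias=False),
    torch.nn.Linear(512, 256, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# Swap eligible nn.Linear modules for Float8Linear (config kwarg assumed).
convert_to_float8_training(model, config=Float8LinearConfig())

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
model(x).sum().backward()
```

Centralizing the cast inside the autograd function removes one of the two code paths that previously had to be kept in sync.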