| 11 | 11 | _use_top_left_mask = flash_attn_supports_top_left_mask() | 
| 12 | 12 |  | 
| 13 | 13 |  | 
|  | 14 | +def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype: | 
|  | 15 | +    """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise.""" | 
|  | 16 | +    if query.dtype == torch.float32: | 
|  | 17 | +        if torch.is_autocast_enabled(): | 
|  | 18 | +            return torch.get_autocast_gpu_dtype() | 
|  | 19 | +        # Handle the case where the model is quantized | 
|  | 20 | +        elif hasattr(module.config, "_pre_quantization_dtype"): | 
|  | 21 | +            return module.config._pre_quantization_dtype | 
|  | 22 | +        else: | 
|  | 23 | +            return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype | 
|  | 24 | +    return None | 
|  | 25 | + | 
|  | 26 | + | 
| 14 | 27 | def flash_attention_forward( | 
| 15 | 28 |     module: torch.nn.Module, | 
| 16 | 29 |     query: torch.Tensor, | 
| @@ -48,15 +61,7 @@ def flash_attention_forward( | 
| 48 | 61 |     # cast them back to the correct dtype just to be sure everything works as expected. | 
| 49 | 62 |     # This might slow down training & inference, so it is recommended to not cast the LayerNorms | 
| 50 | 63 |     # in fp32. (usually our RMSNorm modules handle it correctly) | 
| 51 |  | -    target_dtype = None | 
| 52 |  | -    if query.dtype == torch.float32: | 
| 53 |  | -        if torch.is_autocast_enabled(): | 
| 54 |  | -            target_dtype = torch.get_autocast_gpu_dtype() | 
| 55 |  | -        # Handle the case where the model is quantized | 
| 56 |  | -        elif hasattr(module.config, "_pre_quantization_dtype"): | 
| 57 |  | -            target_dtype = module.config._pre_quantization_dtype | 
| 58 |  | -        else: | 
| 59 |  | -            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype | 
|  | 64 | +    target_dtype = get_target_dtype(query, module) | 
| 60 | 65 |  | 
| 61 | 66 |     # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is present | 
| 62 | 67 |     is_causal = kwargs.pop("is_causal", None) | 
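For context, here is a minimal sketch of how the extracted `get_target_dtype` helper behaves and how its result is consumed further down in `flash_attention_forward`. The toy `TinyAttention` module, its empty `config`, and the surrounding driver code are illustrative assumptions only, not part of this change; `get_target_dtype` is assumed to be importable from the patched module.

```python
import types

import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    # Toy stand-in for an attention layer: it exposes a `config` attribute (empty
    # here) and a Linear projection whose weight dtype acts as the fallback target.
    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace()
        self.q_proj = nn.Linear(8, 8, dtype=torch.bfloat16)


module = TinyAttention()
query = torch.randn(1, 2, 4, 8)  # float32, e.g. after a LayerNorm kept in fp32
key, value = torch.randn_like(query), torch.randn_like(query)

# Outside autocast and without a `_pre_quantization_dtype` on the config, the
# helper falls back to the dtype of the first nn.Linear weight it finds (bf16 here).
target_dtype = get_target_dtype(query, module)  # helper added in this diff

# The caller then downcasts the fp32 states before invoking the flash kernel,
# matching the "cast them back to the correct dtype" comment above.
if target_dtype is not None:
    query, key, value = (t.to(target_dtype) for t in (query, key, value))

print(target_dtype, query.dtype)  # torch.bfloat16 torch.bfloat16
```

Under an active CUDA autocast context the same call would instead return the dtype reported by `torch.get_autocast_gpu_dtype()`, and for a quantized model it would return `config._pre_quantization_dtype`, per the branches in the helper above.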