From e70ba278505ce069c0cda3e29ba5b8019f46543a Mon Sep 17 00:00:00 2001 From: debasisdwivedy Date: Sat, 16 Aug 2025 12:44:48 +0530 Subject: [PATCH 1/2] setting MPS flag check for bf16 training issue Signed-off-by: debasisdwivedy --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index da0721eee0c9..db1f6f08c027 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1721,7 +1721,7 @@ def __post_init__(self): # cpu raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") elif not self.use_cpu: - if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support + if not is_torch_bf16_gpu_available() and not is_torch_xla_available() and not is_torch_mps_available(): # added for tpu support error_message = "Your setup doesn't support bf16/gpu." if is_torch_cuda_available(): error_message += " You need Ampere+ GPU with cuda>=11.0" From 46fbe3c22ec9dc0e994f493be145f7fd3affeb1f Mon Sep 17 00:00:00 2001 From: debasisdwivedy Date: Fri, 22 Aug 2025 19:05:22 +0530 Subject: [PATCH 2/2] fixing formatting error Signed-off-by: debasisdwivedy --- src/transformers/training_args.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index db1f6f08c027..2b066f988224 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1721,7 +1721,11 @@ def __post_init__(self): # cpu raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). 
You need torch>=1.10") elif not self.use_cpu: - if not is_torch_bf16_gpu_available() and not is_torch_xla_available() and not is_torch_mps_available(): # added for tpu support + if ( + not is_torch_bf16_gpu_available() + and not is_torch_xla_available() + and not is_torch_mps_available() + ): # added for tpu and mps support error_message = "Your setup doesn't support bf16/gpu." if is_torch_cuda_available(): error_message += " You need Ampere+ GPU with cuda>=11.0"