
Setting MPS flag check for bf16 training issue #40216


Open · wants to merge 2 commits into main

Conversation


@debasisdwivedy commented Aug 16, 2025

What does this PR do?

Adds a check for MPS availability when training in bf16 (torch_dtype=torch.bfloat16).

Fixes #39935
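
For context, a minimal sketch of the kind of availability check involved (the helper name bf16_supported is illustrative only and is not the actual diff in this PR):

# Illustrative sketch, not the code changed by this PR: accept bf16 when either a
# capable CUDA GPU or an MPS device is present, instead of rejecting MPS outright.
import torch
from transformers.utils import is_torch_bf16_gpu_available

def bf16_supported() -> bool:  # hypothetical helper name
    if is_torch_bf16_gpu_available():  # existing CUDA/GPU path
        return True
    return torch.backends.mps.is_available()  # Apple Silicon / MPS path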

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@renet10
@ArthurZucker
@zach-huggingface
@SunMarc
@qgallouedec

I have attached a code sample that was tested on my Mac. Please feel free to do a round of testing on your side.

def train():
    import os
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, TaskType
    from trl import SFTTrainer, SFTConfig
    from dotenv import load_dotenv
    import json

    torch.mps.empty_cache()

    load_dotenv()
    from huggingface_hub import login
    login(token=os.getenv("HUGGINGFACE_API_TOKEN"))

    model_name = "google/gemma-3-270m-it"
    dataset_name = "<TAKE_ANY_SAMPLE_DATASET>"


    # 1. Load the model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation='eager',
        token=os.getenv("HUGGINGFACE_API_TOKEN"),
        device_map='mps',
        torch_dtype=torch.bfloat16,
    )

    # Use AutoTokenizer to add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=os.getenv("HUGGINGFACE_API_TOKEN"),
    )

    ##############################################################################################################

    from datasets import load_dataset

    def preprocess(sample):
        messages = sample["messages"]
        return {"text": tokenizer.apply_chat_template(messages, add_generation_prompt=False,tokenize=False)}


    dataset = load_dataset(dataset_name, split="train[:500]")
    dataset = dataset.rename_column("conversations", "messages")
    
    dataset = dataset.map(preprocess, remove_columns="messages")
    dataset = dataset.train_test_split(0.1)
    print(dataset)

    print(dataset["train"][5]["text"])


    ##############################################################################################################

    username = "JOHN_DOE"  # REPLACE with your Hugging Face username
    output_dir = "gemma-3-270m-it-fine-tuned" # The directory where the trained model checkpoints, logs, and other artifacts will be saved. It will also be the default name of the model when pushed to the hub if not redefined later.
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 2
    gradient_accumulation_steps = 8
    logging_steps = 5
    learning_rate = 1e-5 # The initial learning rate for the optimizer.

    max_grad_norm = 1.0
    num_train_epochs=1
    warmup_ratio = 0.1
    lr_scheduler_type = "cosine"
    max_seq_length = 1024

    # 3. Configure PEFT with LoraConfig
    # Crucially, include the embedding layers in modules_to_save

    peft_config = LoraConfig(r=32,
                            lora_alpha=64,
                            lora_dropout=0.05,
                            target_modules=["gate_proj", "q_proj", "o_proj", "k_proj", "down_proj", "up_proj", "v_proj"],
                            modules_to_save=["embed_tokens", "lm_head"],
                            task_type=TaskType.CAUSAL_LM)

    training_arguments = SFTConfig(
        output_dir=output_dir,
        do_train=True,
        per_device_train_batch_size=per_device_train_batch_size,
        do_eval=True,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        save_strategy="no",
        eval_strategy="epoch",
        logging_steps=logging_steps,
        learning_rate=learning_rate,
        max_grad_norm=max_grad_norm,
        weight_decay=0.1,
        warmup_ratio=warmup_ratio,
        lr_scheduler_type=lr_scheduler_type,
        report_to="tensorboard",
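        # bf16 on an MPS device: the combination that this PR's availability check is meant to allow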
        bf16=True,
        use_mps_device=True,
        seed=123,
        hub_private_repo=False,
        push_to_hub=False,
        num_train_epochs=num_train_epochs,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        #packing=True,
        max_length=max_seq_length,
        remove_unused_columns=False,
        dataset_text_field="text",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-8,
        label_smoothing_factor=0.1,
    )
    ##############################################################################################################

    trainer = SFTTrainer(
        model=model,
        args=training_arguments,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    ##############################################################################################################
    torch.mps.empty_cache()

    trainer.train()
    trainer.save_model(output_dir="./trainer_output")
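
Before calling train(), a quick sanity check along these lines can confirm that the MPS backend and bfloat16 are actually usable on the machine (a rough sketch, assuming a recent PyTorch build with MPS support):

import torch

# Rough sanity check (illustrative): make sure the MPS backend is built and
# available, and that a bfloat16 tensor can be allocated on it.
assert torch.backends.mps.is_built(), "PyTorch was compiled without MPS support"
assert torch.backends.mps.is_available(), "No MPS device available (needs Apple Silicon + recent macOS)"
x = torch.ones(2, 2, dtype=torch.bfloat16, device="mps")
print(x.dtype, x.device)  # expected: torch.bfloat16 mps:0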

Regards,

@ArthurZucker (Collaborator) left a comment

Hey! Can you just summarize: this is to enable bf16 training on the MPS device, right? is_torch_bf16_gpu_available() does not cover MPS, I suppose.

@debasisdwivedy (Author) commented Aug 21, 2025

Yes, correct. To make it work I have raised two PRs; both have to be merged for it to work:

ACCELERATE_PACKAGE
TRANSFORMER_PACKAGE

I tested it with a small dataset on my Mac and it worked.

If you have a Mac with more memory available, please test it before merging.

You can use the code provided and adjust the training parameters accordingly.

Regards,

@ArthurZucker (Collaborator)

Can you fix the quality test please?!

Signed-off-by: debasisdwivedy <[email protected]>
Development

Successfully merging this pull request may close these issues.

Still getting "fp16 mixed precision requires a GPU (not 'mps')." error
2 participants