[megatron] support multimodal model CPT/SFT/DPO (Full/LoRA) #5502
New file — example script: full-parameter SFT of Qwen2.5-VL-7B-Instruct with Megatron (@@ -0,0 +1,31 @@):

# 4 * 56GiB; 2.3s/it
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=4 \
MAX_PIXELS=1003520 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
megatron sft \
    --load Qwen2.5-VL-7B-Instruct-mcore \
    --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite' \
    --tensor_model_parallel_size 2 \
    --sequence_parallel true \
    --packing true \
    --split_dataset_ratio 0.01 \
    --micro_batch_size 1 \
    --global_batch_size 4 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-5 \
    --lr_warmup_fraction 0.05 \
    --min_lr 1e-6 \
    --max_epochs 1 \
    --save megatron_output/Qwen2.5-VL-7B-Instruct \
    --save_interval 200 \
    --vit_gradient_checkpointing true \
    --max_length 2048 \
    --num_workers 4 \
    --no_save_optim true \
    --no_save_rng true \
    --dataset_num_proc 8
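A note on the MAX_PIXELS=1003520 value exported above: under the Qwen2.5-VL preprocessing convention (14x14 patches merged 2x2, so each visual token covers a 28x28-pixel area — our assumption, not something stated in this diff), it caps every image at roughly 1280 visual tokens, which keeps the vision sequence inside the --max_length 2048 budget. A minimal sketch of the arithmetic:

```python
# Rough illustration only. Assumption: Qwen2.5-VL uses 14x14 patches with a 2x2 spatial
# merge, so one visual token corresponds to a 28x28 pixel area.
PATCH_SIZE = 14
MERGE_SIZE = 2
PIXELS_PER_TOKEN = (PATCH_SIZE * MERGE_SIZE) ** 2  # 784

MAX_PIXELS = 1003520  # the value exported in the script above
print(MAX_PIXELS // PIXELS_PER_TOKEN)  # 1280 visual tokens per image, at most
```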
Model package __init__ — import the new qwen2_5_vl subpackage so its registration runs (@@ -1,4 +1,4 @@):

 # Copyright (c) Alibaba, Inc. and its affiliates.
-from . import gpt
+from . import gpt, qwen2_5_vl
 from .constant import MegatronModelType
 from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
MegatronModelType constants — add the qwen2_5_vl model type (@@ -1,3 +1,4 @@):

 # Copyright (c) Alibaba, Inc. and its affiliates.
 class MegatronModelType:
     gpt = 'gpt'
+    qwen2_5_vl = 'qwen2_5_vl'
Megatron GPT model wrapper (the McoreGPTModel subclass) — wire in the optional visual module and multimodal RoPE.

@@ -4,13 +4,14 @@
 from typing import Any, Dict, Literal, Optional

 import torch
-from megatron.core import InferenceParams
+from megatron.core import InferenceParams, mpu
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.models.gpt import GPTModel as McoreGPTModel
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.training import get_args

 from swift.utils import get_logger
 from .rope import dynamic_rope_update, get_rope_inv_freq
@@ -91,6 +92,11 @@ def __init__( | |||||||||
| logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' | ||||||||||
| f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') | ||||||||||
|
|
||||||||||
| args = get_args() | ||||||||||
| self.visual = None | ||||||||||
| if args.megatron_model_meta.visual is not None: | ||||||||||
| self.visual = args.megatron_model_meta.visual(config) | ||||||||||
|
|
||||||||||
| @contextmanager | ||||||||||
| def _patch_apply_rotary_pos_emb(self): | ||||||||||
| if self.attention_scaling == 1.: | ||||||||||
|
|
@@ -138,10 +144,19 @@ def forward( | |||||||||
| # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. | ||||||||||
|
|
||||||||||
| # Decoder embedding. | ||||||||||
| args = get_args() | ||||||||||
| if decoder_input is not None: | ||||||||||
| pass | ||||||||||
| elif self.pre_process: | ||||||||||
| decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) | ||||||||||
| if self.visual is not None: | ||||||||||
| if args.tensor_model_parallel_size > 1 and args.sequence_parallel: | ||||||||||
| input_ids = input_ids.chunk( | ||||||||||
| args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] | ||||||||||
| kwargs.update({'input_ids': input_ids}) | ||||||||||
| decoder_input = decoder_input.transpose(0, 1) | ||||||||||
| decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) | ||||||||||
| decoder_input = decoder_input.transpose(0, 1) | ||||||||||
| else: | ||||||||||
| # intermediate stage of pipeline | ||||||||||
| # decoder will get hidden_states from encoder.input_tensor | ||||||||||
|
|
@@ -172,6 +187,13 @@ def forward( | |||||||||
| rotary_seq_len, | ||||||||||
| packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', | ||||||||||
| ) | ||||||||||
| elif self.position_embedding_type in 'mrope': | ||||||||||
| if self.training or not self.config.flash_decode: | ||||||||||
| rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) | ||||||||||
| else: | ||||||||||
| # Flash decoding uses precomputed cos and sin for RoPE | ||||||||||
| raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in ' | ||||||||||
| 'MultimodalRotaryEmbedding yet.') | ||||||||||
|
||||||||||
| raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in ' | |
| 'MultimodalRotaryEmbedding yet.') | |
| raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implemented in ' | |
| 'MultimodalRotaryEmbedding yet.') |
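Two details in the forward change above are easy to miss. Megatron keeps hidden states as [seq, batch, hidden] while the HF-style visual module expects [batch, seq, hidden], hence the transposes around get_inputs_embeds. And with --sequence_parallel, the embedding output each tensor-parallel rank holds is only its 1/TP slice of the sequence, which (as we read it) is why input_ids are chunked the same way before the visual module looks up image placeholder positions in them. A standalone sketch of the shape bookkeeping (plain PyTorch, illustrative sizes only — not the PR's code):

```python
# Illustrative shapes only: matching the input_ids slice to a sequence-parallel
# shard of the embedding output before handing both to an HF-style module.
import torch

seq_len, batch, hidden = 8, 2, 16
tp_size, tp_rank = 2, 0                                         # pretend we are TP rank 0 of 2

input_ids = torch.randint(0, 100, (batch, seq_len))             # [b, s]
decoder_input = torch.randn(seq_len // tp_size, batch, hidden)  # [s/tp, b, h] shard

# Take the matching 1/tp_size slice of the token ids (dim=-1 is the sequence dim).
ids_shard = input_ids.chunk(tp_size, dim=-1)[tp_rank]           # [b, s/tp]

# Megatron layout is [s, b, h]; an HF-style module expects [b, s, h].
hf_layout = decoder_input.transpose(0, 1)                       # [b, s/tp, h]
assert hf_layout.shape[:2] == ids_shard.shape
```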
New file — register the qwen2_5_vl Megatron model (@@ -0,0 +1,15 @@):

from swift.llm import ModelType
from ..constant import MegatronModelType
from ..gpt import GptMegatronModelMeta
from ..register import MegatronModelMeta, register_megatron_model
from .convert import convert_hf2mcore_qwen2_5_vl, convert_mcore2hf_qwen2_5_vl
from .vit import Qwen2_5VL_Vit

register_megatron_model(
    GptMegatronModelMeta(
        MegatronModelType.qwen2_5_vl, [
            ModelType.qwen2_5_vl,
        ],
        convert_hf2mcore=convert_hf2mcore_qwen2_5_vl,
        convert_mcore2hf=convert_mcore2hf_qwen2_5_vl,
        visual=Qwen2_5VL_Vit))
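For orientation, a sketch (not code from this PR) of how the registration above is consumed: the meta is looked up from the model type and ends up on the Megatron args, which is how the GPT model's __init__ earlier in this diff reaches args.megatron_model_meta.visual. The import path and the lookup-by-model-type signature of get_megatron_model_meta are assumptions based on the package __init__ shown above, not something this diff spells out:

```python
# Sketch under assumptions: resolving the registered Qwen2.5-VL meta.
# The import path and lookup signature are assumed from the package __init__ above.
from swift.megatron.model import get_megatron_model_meta

meta = get_megatron_model_meta('qwen2_5_vl')
assert meta.visual is not None            # Qwen2_5VL_Vit, passed via visual=... above
assert meta.convert_hf2mcore is not None  # used when converting HF weights to mcore
```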
New file — HF ↔ Megatron-Core weight converters for Qwen2.5-VL (@@ -0,0 +1,31 @@):

from megatron.training import get_args

from swift.utils import deep_getattr
from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore
from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf


def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model):
    language_model = hf_model.model.language_model
    args = get_args()
    # language_model
    mg_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight)
    if args.untie_embeddings_and_output_weights:
        mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight)
    mg_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight)
    for layer_idx in range(args.num_layers):
        set_layer_state_hf2mcore(args, mg_model, language_model, layer_idx)
    mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict())


def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model):
    language_model = hf_model.model.language_model
    args = get_args()
    # language_model
    language_model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight)
    if args.untie_embeddings_and_output_weights:
        hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
    language_model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
    for layer_idx in range(args.num_layers):
        set_layer_state_mcore2hf(args, mg_model, language_model, layer_idx)
    hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict())
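The two converters are meant to be inverses: the language-model weights are copied tensor by tensor and the vision tower round-trips through load_state_dict. A quick way to sanity-check this — a hypothetical helper, not part of this PR — is to convert HF → mcore → HF and confirm nothing drifted:

```python
# Hypothetical sanity check (not part of this PR): verify that hf -> mcore -> hf
# round-trips the embedding and vision weights exactly.
import torch


def check_roundtrip(hf_model, mg_model):
    # Snapshot a few reference tensors before touching anything.
    ref_embed = hf_model.model.language_model.embed_tokens.weight.detach().clone()
    ref_visual = {k: v.detach().clone() for k, v in hf_model.model.visual.state_dict().items()}

    convert_hf2mcore_qwen2_5_vl(hf_model, mg_model)   # HF -> Megatron-Core
    convert_mcore2hf_qwen2_5_vl(hf_model, mg_model)   # Megatron-Core -> HF

    assert torch.equal(hf_model.model.language_model.embed_tokens.weight, ref_embed)
    for k, v in hf_model.model.visual.state_dict().items():
        assert torch.equal(v, ref_visual[k]), f'vision weight drifted: {k}'
```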
New file — Qwen2_5VL_Vit, a HuggingFaceModule wrapping the Qwen2.5-VL vision tower (only the beginning of the 90-line file appears in this view; @@ -0,0 +1,90 @@):

from contextlib import contextmanager

import torch
from megatron.core.models.huggingface import HuggingFaceModule
from megatron.training import get_args

from swift.llm import get_model_tokenizer, to_device


@contextmanager
def patch_device_map_meta(model_cls):
    __origin_init__ = model_cls.__init__

    def __init__(self, *args, **kwargs):
        with torch.device('meta'):
            __origin_init__(self, *args, **kwargs)

    model_cls.__init__ = __init__
    try:
        yield
    finally:
        model_cls.__init__ = __origin_init__


class Qwen2_5VL_Vit(HuggingFaceModule):

    def __init__(self, config):
        from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel
        super().__init__(config)
        args = get_args()
        model_dir = args.model_info.model_dir
        kwargs = {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {}
        with patch_device_map_meta(Qwen2_5_VLTextModel):
            model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)

Suggested change (review) — store the returned processor on self instead of discarding it:
-            model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)
+            model, self.processor = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)
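patch_device_map_meta above temporarily swaps the text model's __init__ so that its parameters are created on PyTorch's meta device — shapes and dtypes exist, but no storage is allocated — presumably so that the language-model half is never materialized here, where only the vision tower is needed. The trick itself is plain PyTorch; a minimal standalone illustration (not the PR's code):

```python
# Minimal standalone illustration of the meta-device trick (requires PyTorch >= 2.0,
# where torch.device can be used as a context manager).
import torch
import torch.nn as nn

with torch.device('meta'):
    big_layer = nn.Linear(8192, 8192)   # parameters get shape/dtype but no real storage

print(big_layer.weight.device)                                        # meta
print(big_layer.weight.nelement() * big_layer.weight.element_size())  # bytes it would need
```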
Review comment: The condition `self.position_embedding_type in 'mrope'` is likely a bug. It checks whether `self.position_embedding_type` is a substring of 'mrope' (so 'rope', for example, would also satisfy it), not whether it equals the string 'mrope'. It happens to work today only because the other position embedding types are caught by the earlier branches, but it is fragile and incorrect. It should be changed to a direct string comparison for correctness and clarity.
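A quick demonstration of the pitfall (illustrative snippet, not code from the PR): for strings, `in` tests substring containment, so the check accepts more values than intended, while equality — or membership in an explicit tuple of valid values — does not:

```python
# Substring containment vs. equality for the position-embedding check.
position_embedding_type = 'rope'

print(position_embedding_type in 'mrope')     # True  -- 'rope' is a substring of 'mrope'
print(position_embedding_type == 'mrope')     # False -- the intended comparison
print(position_embedding_type in ('mrope',))  # False -- membership in an explicit tuple also works
```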