Merged
35 commits:
abba497 update (Jintao-Huang, Aug 22, 2025)
0883b84 update (Jintao-Huang, Aug 24, 2025)
bdbaa9a update (Jintao-Huang, Aug 24, 2025)
efd6f72 Merge branch 'main' into support_megatron_multimodal (Jintao-Huang, Aug 26, 2025)
c0b28b4 Merge branch 'main' into support_megatron_multimodal (Jintao-Huang, Aug 29, 2025)
0e60545 update (Jintao-Huang, Aug 29, 2025)
b349b56 update (Jintao-Huang, Aug 31, 2025)
6a77b0e update (Jintao-Huang, Aug 31, 2025)
26d5f64 Merge branch 'main' into support_megatron_multimodal (Jintao-Huang, Aug 31, 2025)
a502bbb update (Jintao-Huang, Aug 31, 2025)
2e92219 Merge branch 'main' into support_megatron_multimodal (Jintao-Huang, Sep 1, 2025)
6fad478 fix (Jintao-Huang, Sep 1, 2025)
2878d15 update (Jintao-Huang, Sep 1, 2025)
0cbada0 update (Jintao-Huang, Sep 1, 2025)
3151603 fix (Jintao-Huang, Sep 1, 2025)
93c4693 update (Jintao-Huang, Sep 1, 2025)
308d565 update (Jintao-Huang, Sep 1, 2025)
08320eb update (Jintao-Huang, Sep 1, 2025)
44a95bd fix (Jintao-Huang, Sep 1, 2025)
50b2eb1 lint pass (Jintao-Huang, Sep 1, 2025)
51f315a fix cp (Jintao-Huang, Sep 1, 2025)
ef66901 update (Jintao-Huang, Sep 1, 2025)
631691c lint pass (Jintao-Huang, Sep 1, 2025)
587298e update (Jintao-Huang, Sep 2, 2025)
a1cb64b update (Jintao-Huang, Sep 2, 2025)
798bdd4 fix (Jintao-Huang, Sep 2, 2025)
5953506 fix (Jintao-Huang, Sep 2, 2025)
0508aaf lint pass (Jintao-Huang, Sep 2, 2025)
6a431aa update (Jintao-Huang, Sep 2, 2025)
3f35a86 update (Jintao-Huang, Sep 2, 2025)
0156a0c update (Jintao-Huang, Sep 2, 2025)
4c810cb fix (Jintao-Huang, Sep 2, 2025)
d90d282 fix (Jintao-Huang, Sep 2, 2025)
f89a2bb update (Jintao-Huang, Sep 2, 2025)
dc803a0 update (Jintao-Huang, Sep 3, 2025)
31 changes: 31 additions & 0 deletions examples/megatron/multimodal/dense.sh
@@ -0,0 +1,31 @@
# 4 * 56GiB; 2.3s/it
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=4 \
MAX_PIXELS=1003520 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
megatron sft \
--load Qwen2.5-VL-7B-Instruct-mcore \
--dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite' \
--tensor_model_parallel_size 2 \
--sequence_parallel true \
--packing true \
--split_dataset_ratio 0.01 \
--micro_batch_size 1 \
--global_batch_size 4 \
--recompute_granularity full \
--recompute_method uniform \
--recompute_num_layers 1 \
--finetune true \
--cross_entropy_loss_fusion true \
--lr 1e-5 \
--lr_warmup_fraction 0.05 \
--min_lr 1e-6 \
--max_epochs 1 \
--save megatron_output/Qwen2.5-VL-7B-Instruct \
--save_interval 200 \
--vit_gradient_checkpointing true \
--max_length 2048 \
--num_workers 4 \
--no_save_optim true \
--no_save_rng true \
--dataset_num_proc 8
1 change: 1 addition & 0 deletions examples/train/multimodal/lora_llm_full_vit/sft.sh
@@ -1,5 +1,6 @@
# 4 * 22GiB
# vit/merger lr 1e-5; llm lora lr 1e-4
# Note: resume_from_checkpoint is not supported (only resume_only_model is supported)
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
MAX_PIXELS=1003520 \
11 changes: 10 additions & 1 deletion swift/megatron/argument/megatron_args.py
@@ -94,6 +94,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
partial_rotary_factor: Optional[float] = None
use_shared_expert_gate: Optional[bool] = None

# visual
vit_gradient_checkpointing: bool = True
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None


@dataclass
class MegatronArguments(ExtraMegatronArguments):
@@ -185,7 +189,8 @@ class MegatronArguments(ExtraMegatronArguments):
group_query_attention: Optional[bool] = None
num_query_groups: Optional[int] = None
max_position_embeddings: Optional[int] = None
position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none'] = 'rope'
position_embedding_type: Optional[Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none']] = None
mrope_section: Optional[List[int]] = None
rotary_base: Optional[int] = None
rotary_percent: float = 1.
rotary_interleaved: Optional[bool] = None
@@ -376,10 +381,14 @@ def __post_init__(self):
self.rope_scaling = json_parse_to_dict(self.rope_scaling)
if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling:
self.rope_scaling['rope_type'] = self.rope_scaling['type']
if self.gradient_checkpointing_kwargs is not None:
self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs)
if self.eval_interval is None:
self.eval_interval = self.save_interval
if self.seq_length is None:
self.seq_length = self.max_position_embeddings
if self.position_embedding_type is None:
self.position_embedding_type = 'rope'
if self.tensorboard_dir is None and self.save is not None:
self.tensorboard_dir = f'{self.save}/runs'
self._init_moe()
3 changes: 3 additions & 0 deletions swift/megatron/argument/train_args.py
@@ -28,6 +28,9 @@ def init_model_args(self, tokenizer, config):
setattr(self, k, v)
MegatronArguments.__post_init__(self)
self.extra_args = self.parse_to_megatron()
self.extra_args['model_info'] = self.model_info
self.extra_args['model_meta'] = self.model_meta
self.extra_args['megatron_model_meta'] = self.megatron_model_meta

def _init_save(self):
init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout)
2 changes: 1 addition & 1 deletion swift/megatron/model/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import gpt
from . import gpt, qwen2_5_vl
from .constant import MegatronModelType
from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
1 change: 1 addition & 0 deletions swift/megatron/model/constant.py
@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
class MegatronModelType:
gpt = 'gpt'
qwen2_5_vl = 'qwen2_5_vl'
14 changes: 14 additions & 0 deletions swift/megatron/model/gpt/__init__.py
@@ -1,4 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Any, Callable, Dict

from torch import nn
from transformers import PretrainedConfig

from swift.llm import ModelType
from ..constant import MegatronModelType
from ..register import MegatronModelMeta, register_megatron_model
@@ -53,3 +59,11 @@
ModelType.glm4_5,
ModelType.deepseek_v3_1,
], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore))


@dataclass
class GptMegatronModelMeta(MegatronModelMeta):
model_provider: Callable[[], nn.Module] = model_provider
convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config
convert_mcore2hf: Callable[[nn.Module, nn.Module], None] = convert_mcore2hf
convert_hf2mcore: Callable[[nn.Module, nn.Module], None] = convert_hf2mcore
3 changes: 3 additions & 0 deletions swift/megatron/model/gpt/config.py
@@ -39,6 +39,9 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]:
res['rotary_interleaved'] = True
elif architectures == 'Glm4MoeForCausalLM':
res['moe_router_score_function'] = 'sigmoid'
elif architectures == 'Qwen2_5_VLForConditionalGeneration':
res['position_embedding_type'] = 'mrope'
res['mrope_section'] = res['rope_scaling']['mrope_section']
if first_k_dense_replace is not None:
res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}'
if res.get('moe_router_score_function', 'softmax') == 'sigmoid':
4 changes: 2 additions & 2 deletions swift/megatron/model/gpt/hf2mcore.py
@@ -91,7 +91,7 @@ def set_mlp_state(args, mg_mlp, hf_mlp):

def set_layer_state(args, mg_model, hf_model, layer_idx):
mg_layer = mg_model.decoder.layers[layer_idx]
hf_layer = hf_model.model.layers[layer_idx]
hf_layer = hf_model.layers[layer_idx]
if args.multi_latent_attention:
set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight)
@@ -115,4 +115,4 @@ def convert_hf2mcore(hf_model, mg_model):
mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight)
mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight)
for layer_idx in range(args.num_layers):
set_layer_state(args, mg_model, hf_model, layer_idx)
set_layer_state(args, mg_model, hf_model.model, layer_idx)
4 changes: 2 additions & 2 deletions swift/megatron/model/gpt/mcore2hf.py
@@ -88,7 +88,7 @@ def set_mlp_state(args, mg_mlp, hf_mlp):

def set_layer_state(args, mg_model, hf_model, layer_idx):
mg_layer = mg_model.decoder.layers[layer_idx]
hf_layer = hf_model.model.layers[layer_idx]
hf_layer = hf_model.layers[layer_idx]

if args.multi_latent_attention:
set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
@@ -113,4 +113,4 @@ def convert_mcore2hf(hf_model, mg_model):
hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
for layer_idx in range(args.num_layers):
set_layer_state(args, mg_model, hf_model, layer_idx)
set_layer_state(args, mg_model, hf_model.model, layer_idx)
24 changes: 23 additions & 1 deletion swift/megatron/model/gpt_model.py
@@ -4,13 +4,14 @@
from typing import Any, Dict, Literal, Optional

import torch
from megatron.core import InferenceParams
from megatron.core import InferenceParams, mpu
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt import GPTModel as McoreGPTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.training import get_args

from swift.utils import get_logger
from .rope import dynamic_rope_update, get_rope_inv_freq
@@ -91,6 +92,11 @@ def __init__(
logger.warning('`apply_rope_fusion` does not support `attention_scaling`. '
f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}')

args = get_args()
self.visual = None
if args.megatron_model_meta.visual is not None:
self.visual = args.megatron_model_meta.visual(config)

@contextmanager
def _patch_apply_rotary_pos_emb(self):
if self.attention_scaling == 1.:
@@ -138,10 +144,19 @@ def forward(
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

# Decoder embedding.
args = get_args()
if decoder_input is not None:
pass
elif self.pre_process:
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
if self.visual is not None:
if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
input_ids = input_ids.chunk(
args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()]
kwargs.update({'input_ids': input_ids})
decoder_input = decoder_input.transpose(0, 1)
decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs)
decoder_input = decoder_input.transpose(0, 1)
else:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
@@ -172,6 +187,13 @@
rotary_seq_len,
packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd',
)
elif self.position_embedding_type in 'mrope':
Reviewer comment (Contributor), severity: high
The condition self.position_embedding_type in 'mrope' is likely a bug. On strings, the in operator tests substring containment, so this branch would also match any value that happens to be a substring of 'mrope' (for example, 'rope'), not only the string 'mrope' itself. It works for now only because 'rope' appears to be handled by the preceding branch and no other position embedding type is a substring of 'mrope', but it is fragile and incorrect. It should be changed to a direct string comparison for correctness and clarity.

Suggested change
elif self.position_embedding_type in 'mrope':
elif self.position_embedding_type == 'mrope':
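
A minimal Python sketch (illustrative only, not part of this PR) of the substring-containment behaviour the comment describes; the variable value here is hypothetical:

# On strings, `in` tests substring containment rather than equality.
position_embedding_type = 'rope'               # hypothetical value
print(position_embedding_type in 'mrope')      # True: 'rope' is a substring of 'mrope'
print(position_embedding_type == 'mrope')      # False: equality is what this branch intends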

if self.training or not self.config.flash_decode:
rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in '
'MultimodalRotaryEmbedding yet.')
Reviewer comment (Contributor), severity: medium
There is a typo in the error message. "implmented" should be "implemented".

Suggested change
raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in '
'MultimodalRotaryEmbedding yet.')
raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implemented in '
'MultimodalRotaryEmbedding yet.')

if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None
and inference_params):
sequence_len_offset = torch.tensor(
15 changes: 15 additions & 0 deletions swift/megatron/model/qwen2_5_vl/__init__.py
@@ -0,0 +1,15 @@
from swift.llm import ModelType
from ..constant import MegatronModelType
from ..gpt import GptMegatronModelMeta
from ..register import MegatronModelMeta, register_megatron_model
from .convert import convert_hf2mcore_qwen2_5_vl, convert_mcore2hf_qwen2_5_vl
from .vit import Qwen2_5VL_Vit

register_megatron_model(
GptMegatronModelMeta(
MegatronModelType.qwen2_5_vl, [
ModelType.qwen2_5_vl,
],
convert_hf2mcore=convert_hf2mcore_qwen2_5_vl,
convert_mcore2hf=convert_mcore2hf_qwen2_5_vl,
visual=Qwen2_5VL_Vit))
31 changes: 31 additions & 0 deletions swift/megatron/model/qwen2_5_vl/convert.py
@@ -0,0 +1,31 @@
from megatron.training import get_args

from swift.utils import deep_getattr
from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore
from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf


def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model):
language_model = hf_model.model.language_model
args = get_args()
# language_model
mg_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight)
if args.untie_embeddings_and_output_weights:
mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight)
mg_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight)
for layer_idx in range(args.num_layers):
set_layer_state_hf2mcore(args, mg_model, language_model, layer_idx)
mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict())


def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model):
language_model = hf_model.model.language_model
args = get_args()
# language_model
language_model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight)
if args.untie_embeddings_and_output_weights:
hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
language_model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
for layer_idx in range(args.num_layers):
set_layer_state_mcore2hf(args, mg_model, language_model, layer_idx)
hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict())
90 changes: 90 additions & 0 deletions swift/megatron/model/qwen2_5_vl/vit.py
@@ -0,0 +1,90 @@
from contextlib import contextmanager

import torch
from megatron.core.models.huggingface import HuggingFaceModule
from megatron.training import get_args

from swift.llm import get_model_tokenizer, to_device


@contextmanager
def patch_device_map_meta(model_cls):
__origin_init__ = model_cls.__init__

def __init__(self, *args, **kwargs):
with torch.device('meta'):
__origin_init__(self, *args, **kwargs)

model_cls.__init__ = __init__
try:
yield
finally:
model_cls.__init__ = __origin_init__


class Qwen2_5VL_Vit(HuggingFaceModule):

def __init__(self, config):
from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel
super().__init__(config)
args = get_args()
model_dir = args.model_info.model_dir
kwargs = {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {}
with patch_device_map_meta(Qwen2_5_VLTextModel):
model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)
Reviewer comment (Contributor), severity: critical
The processor is not stored after being loaded from get_model_tokenizer, but it's used later in get_inputs_embeds (e.g., self.processor.image_processor). This will raise an AttributeError at runtime. You should store the returned processor in self.processor.

Suggested change
model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)
model, self.processor = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs)

self.model = model.visual.to('cuda')
self.model_config = model.config

def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def get_inputs_embeds(self, inputs_embeds, **kwargs):
input_ids = kwargs['input_ids']
pixel_values = kwargs.get('pixel_values')
pixel_values_videos = kwargs.get('pixel_values_videos')
image_grid_thw = kwargs.get('image_grid_thw')
video_grid_thw = kwargs.get('video_grid_thw')
dtype = self.model.dtype
if pixel_values is None and pixel_values_videos is None: # plain-text
from PIL import Image
images = [Image.new('RGB', (32, 32), (0, 0, 0))]
media_inputs = self.processor.image_processor(images=images, return_tensors='pt')
device = input_ids.device
media_inputs = to_device(media_inputs, device)
pixel_values = media_inputs['pixel_values'].type(dtype)
image_embeds = self.model(pixel_values, grid_thw=media_inputs['image_grid_thw'])
inputs_embeds += image_embeds.mean() * 0.
else:
if pixel_values is None:
pixel_values_mixed = pixel_values_videos
grid_thw = video_grid_thw
elif pixel_values_videos is None:
pixel_values_mixed = pixel_values
grid_thw = image_grid_thw
else:
pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
pixel_values_mixed = pixel_values_mixed.type(dtype)
mixed_embeds = self.model(pixel_values_mixed, grid_thw=grid_thw)
if pixel_values is None:
image_embeds = None
video_embeds = mixed_embeds
elif pixel_values_videos is None:
image_embeds = mixed_embeds
video_embeds = None
else:
merge_length = self.processor.image_processor.merge_size**2
image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
image_embeds = mixed_embeds[:image_tokens]
video_embeds = mixed_embeds[image_tokens:]

if image_embeds is not None:
image_mask = (input_ids == self.model_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if video_embeds is not None:
video_mask = (input_ids == self.model_config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
return inputs_embeds
3 changes: 2 additions & 1 deletion swift/megatron/model/register.py
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Type

import torch.nn as nn
from transformers import PretrainedConfig
@@ -20,6 +20,7 @@ class MegatronModelMeta:
convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]]
convert_mcore2hf: Callable[[nn.Module, nn.Module], None]
convert_hf2mcore: Callable[[nn.Module, nn.Module], None]
visual: Optional[Type[nn.Module]] = None

extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None
