
fix Glm4v batch videos forward #39172


Open
wants to merge 7 commits into base: main

Conversation

Contributor

@Kuangdd01 Kuangdd01 commented Jul 2, 2025

What does this PR do?

Fixes issues in video_processing and get_video_features for GLM4V.

Tested with the following script:

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, Glm4vForConditionalGeneration

def prepare_video_metadata(videos):
    video_metadata = []
    for video in videos:
        if isinstance(video, list):
            num_frames = len(video)
        elif hasattr(video, "shape"):
            if len(video.shape) == 4:  # (T, H, W, C)
                num_frames = video.shape[0]
            else:
                num_frames = 1
        else:
            # Fall back to a fixed frame count when it cannot be inferred.
            num_frames = 8

        metadata = {
            "fps": 2,
            "duration": num_frames / 2,
            "total_frames": num_frames,
        }
        video_metadata.append(metadata)
    return video_metadata

def test_video_processing(video_path_list, num_frames=4):
    selected_frames = []
    for video_path in video_path_list:
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames: {frame_count}")
        cap.release()

    # Sample `num_frames` evenly spaced frames from each video as PIL images.
    for video_path in video_path_list:
        temp_frames = []
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(frame_count // num_frames, 1)
        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            temp_frames.append(pil_img)
        cap.release()
        selected_frames.append(temp_frames)

    video_metadata = prepare_video_metadata(selected_frames)
    video_inputs = processor.video_processor(videos=selected_frames, video_metadata=video_metadata)

    questions = ["What kind of dog is this?", "Describe the background."]

    messages_batch = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "video"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        for question in questions
    ]

    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages_batch
    ]

    inputs_batch = processor(text=texts, videos=selected_frames, video_metadata=video_metadata, return_tensors="pt", padding=True)

    print(processor.batch_decode(inputs_batch['input_ids'])[0])
    rope_pos, deltas = model.model.get_rope_index(
        inputs_batch["input_ids"],
        None,
        inputs_batch["video_grid_thw"],
        inputs_batch["attention_mask"]
    )

    print(rope_pos.shape, "\n", deltas)

processor_name = "THUDM/GLM-4.1V-9B-Thinking"
processor = AutoProcessor.from_pretrained(processor_name)
model = Glm4vForConditionalGeneration.from_pretrained(processor_name)

if __name__ == "__main__":
    # image_path = "./data/mllm_demo_data/1.jpg"
    video_path_1 = "./data/mllm_demo_data/1.mp4"
    video_path_2 = "./data/mllm_demo_data/2.avi"

    test_video_processing([video_path_1, video_path_2])

For forward logits checking, cc @zRzRzRzRzRzRzR

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp cc @zRzRzRzRzRzRzR

@Kuangdd01
Contributor Author

CI failed because the change to get_video_features is not consistent with the code generated from the modular file. 😂

total_frames = video_grid_thw[0][0].item()
h = video_grid_thw[0][1].item()
w = video_grid_thw[0][2].item()
video_grid_thw = [[1, h, w] for _ in range(total_frames)]

Member

I think we would also need to pad timestamps, as otherwise it will fail when a different number of frames is sampled per video. We've been discussing it internally with @zRzRzRzRzRzRzR, not sure though if he has a PR yet.
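
For illustration, the padding could look something like the sketch below; pad_timestamps and the -1.0 fill value are made up here for the example, not part of this PR.

import torch

def pad_timestamps(timestamps_per_video, pad_value=-1.0):
    # Pad per-video timestamp lists to the length of the longest video so they
    # can be stacked into a single (num_videos, max_frames) tensor.
    max_len = max(len(ts) for ts in timestamps_per_video)
    padded = torch.full((len(timestamps_per_video), max_len), pad_value)
    for i, ts in enumerate(timestamps_per_video):
        padded[i, : len(ts)] = torch.tensor(ts, dtype=padded.dtype)
    return padded

# e.g. one video with 4 sampled frames and one with 6 -> shape (2, 6)
# pad_timestamps([[0.0, 0.5, 1.0, 1.5], [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]])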

Contributor Author

Yes, returning timestamps here is not great; can we return it the way qwen2_5_vl does?

if isinstance(fps, (int, float)):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
    raise ValueError(
        f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
    )
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

Comment on lines +192 to +198
num_image_tokens = (
    video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
    if self.image_token in text[i]:
        text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

Member

Perfect, this has been itching me since release ❤️ I agree this works when an equal number of frames is sampled per video.

Comment on lines +1182 to +1188
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
    repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
    temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)

Member

Oh, probably because this is just copied from Qwen2-VL when running modular. To actually fix it, we need to override get_video_features in modular_glm4v.py instead of inheriting it from Qwen2-VL.
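
For illustration, the override in modular_glm4v.py could look roughly like the sketch below, wrapping the per-frame expansion from the diff above; the exact signature and return handling are assumptions based on the Qwen2-VL version, not the final code.

# Sketch only: assumes torch is already imported at module level and that the
# signature mirrors the Qwen2-VL get_video_features it replaces.
def get_video_features(self, pixel_values_videos, video_grid_thw=None):
    pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
    # GLM-4V's vision tower expects one [1, h, w] row per frame, so expand each
    # per-video [t, h, w] entry into t rows before calling it.
    temp_frames_hw = []
    for t, h, w in video_grid_thw:
        repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
        temp_frames_hw.append(repeated_row)
    flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
    video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
    return video_embeds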

Contributor Author

Thanks! I have overridden this function in modular_glm4v.py.

Contributor

github-actions bot commented Jul 2, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: glm4v

@Kuangdd01
Contributor Author

😀 Do I need to write more unit tests for this change?
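
If it helps, a batched-video test could look roughly like the sketch below; the dummy inputs, questions, and assertions are assumptions for illustration, not an existing test in the repo.

from PIL import Image
from transformers import AutoProcessor

def test_batch_videos_with_different_frame_counts():
    processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
    # Two dummy videos with different numbers of frames.
    videos = [
        [Image.new("RGB", (64, 64)) for _ in range(4)],
        [Image.new("RGB", (64, 64)) for _ in range(8)],
    ]
    video_metadata = [
        {"fps": 2, "duration": 2.0, "total_frames": 4},
        {"fps": 2, "duration": 4.0, "total_frames": 8},
    ]
    texts = [
        processor.apply_chat_template(
            [{"role": "user", "content": [{"type": "video"}, {"type": "text", "text": q}]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for q in ("What is in the video?", "Describe the scene.")
    ]
    inputs = processor(
        text=texts, videos=videos, video_metadata=video_metadata, return_tensors="pt", padding=True
    )
    # The batch should stay rectangular and keep a video grid for each input.
    assert inputs["input_ids"].shape[0] == 2
    assert "video_grid_thw" in inputs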
