
fix Glm4v batch videos forward #39172


Open · wants to merge 7 commits into main
16 changes: 13 additions & 3 deletions src/transformers/models/glm4v/modeling_glm4v.py
@@ -1053,6 +1053,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1082,7 +1083,6 @@ def get_rope_index(

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1126,7 +1126,11 @@ def get_rope_index(
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1

@@ -1175,7 +1179,13 @@ def get_video_features(
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
Comment on lines +1182 to +1188

Member:
oh, prob because this is just copied from Qwen2-VL when running modular. To actually fix it, we need to overwrite get_video_features in modular_glm4v.py instead of inheriting from Qwen2-VL

Contributor Author:
Thanks! I have overwritten this function in modular_glm4v.py.

split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds
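
For illustration, a minimal standalone sketch of what this hunk computes, with made-up grid values and an assumed spatial_merge_size of 2:

import torch

# Two videos in one batch: 4 frames of a 12x16 patch grid, 2 frames of 8x8.
video_grid_thw = torch.tensor([[4, 12, 16], [2, 8, 8]])

# Flatten each [t, h, w] row into t per-frame rows [1, h, w], as the new code does.
temp_frames_hw = []
for t, h, w in video_grid_thw:
    temp_frames_hw.append(torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t.item(), 1))
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
print(flattened_video_grid_thw.shape)  # torch.Size([6, 3]) -- one row per frame

# The split back into per-video chunks is unchanged: t * h * w // merge**2 tokens each.
spatial_merge_size = 2  # assumed; the real value comes from the vision config
split_sizes = (video_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(split_sizes)  # [192, 32]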
50 changes: 42 additions & 8 deletions src/transformers/models/glm4v/modular_glm4v.py
@@ -1087,6 +1087,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1116,7 +1117,6 @@ def get_rope_index(

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1160,7 +1160,11 @@ def get_rope_index(
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1

@@ -1196,6 +1200,30 @@ def get_rope_index(

return position_ids, mrope_position_deltas

def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds

@auto_docstring
@can_return_tuple
def forward(
@@ -1687,32 +1715,38 @@ def __call__(
video_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_frames = len(video_grid_thw)
num_frames = video_grid_thw[video_index][0]
video_structure = ""

if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0]
else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

unique_timestamps = []
for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx])

selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure

text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

video_index += 1

for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
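
To make the loop above concrete, a string-level sketch of the expansion for one two-frame video; the token literals and counts are made up, only the replace sequence mirrors the code:

# Hypothetical token literals; the real ones come from the tokenizer config.
image_token, video_token = "<image>", "<video>"
timestamps = [0.5, 1.5]   # one timestamp per sampled frame, made up
num_image_tokens = 4      # placeholder tokens per frame, made up

text = f"{video_token} Describe the video."
video_structure = "".join(
    f"<|begin_of_image|>{image_token}<|end_of_image|>{ts}" for ts in timestamps
)
text = text.replace(video_token, video_structure, 1)
for _ in range(len(timestamps)):  # one expansion per frame
    text = text.replace(image_token, "<|placeholder|>" * num_image_tokens, 1)
text = text.replace("<|placeholder|>", image_token)
print(text)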
18 changes: 12 additions & 6 deletions src/transformers/models/glm4v/processing_glm4v.py
@@ -167,32 +167,38 @@ def __call__(
video_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_frames = len(video_grid_thw)
num_frames = video_grid_thw[video_index][0]
video_structure = ""

if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0]
else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

unique_timestamps = []
for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx])

selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure

text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

Comment on lines +192 to +198

Member:
Perfect, this has been itching me since release ❤️ I agree this works when an equal number of frames is sampled per video.
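
As a sanity check on the per-frame placeholder count the new code computes, a hedged sketch with made-up grids and merge_length assumed to be spatial_merge_size ** 2 = 4:

import torch

merge_length = 4  # assumed: spatial_merge_size ** 2
video_grid_thw = torch.tensor([[4, 12, 16], [2, 8, 8]])

for video_index in range(len(video_grid_thw)):
    num_frames = video_grid_thw[video_index][0]
    num_image_tokens = video_grid_thw[video_index].prod() // merge_length // num_frames
    print(int(num_frames), int(num_image_tokens))
# 4 48  -> 4 frames, 48 placeholder tokens per frame
# 2 16  -> 2 frames, 16 placeholder tokens per frame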

video_index += 1

for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
4 changes: 0 additions & 4 deletions src/transformers/models/glm4v/video_processing_glm4v.py
@@ -246,10 +246,6 @@ def _preprocess(
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
pixel_values_videos = torch.cat(processed_videos, dim=0)
video_grid_thw = torch.tensor(processed_grids)
total_frames = video_grid_thw[0][0].item()
h = video_grid_thw[0][1].item()
w = video_grid_thw[0][2].item()
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
Member:
I think we would also need to pad timestamps, as otherwise it will fail when different numbers of frames are sampled per video (see the sketch after this thread). We've been discussing it internally with @zRzRzRzRzRzRzR, not sure though if he has any PR yet.

Contributor Author:
Yes, timestamps is not good to return here. Can we return it like qwen2_5vl does?

if isinstance(fps, (int, float)):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
    raise ValueError(
        f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
    )
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
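
A minimal sketch of the padding the Member comment above suggests, assuming timestamps arrive as one list per video; this is an illustration of the idea, not part of this PR:

# Hypothetical: pad each video's timestamp list to the batch's max frame count
# by repeating the last timestamp, so downstream indexing cannot run out of range.
timestamps = [[0.0, 0.5, 1.0, 1.5], [0.0, 1.0]]
max_len = max(len(ts) for ts in timestamps)
padded = [ts + [ts[-1]] * (max_len - len(ts)) for ts in timestamps]
print(padded)  # [[0.0, 0.5, 1.0, 1.5], [0.0, 1.0, 1.0, 1.0]]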

data = {
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,