Merged
Changes from 6 commits
16 changes: 13 additions & 3 deletions src/transformers/models/glm4v/modeling_glm4v.py
@@ -1053,6 +1053,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1082,7 +1083,6 @@

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1126,7 +1126,11 @@
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1
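For context, a standalone sketch (made-up grid values, plain Python, not part of the diff) of what the new counters do in get_rope_index: each video contributes video_grid_thw[video_index][0] temporal groups, and video_index only advances once all groups of the current video have been consumed.

import torch

# Hypothetical grids: two videos, the first with 4 temporal groups, the second with 2.
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

video_index = 0
video_group_index = 0
visited = []  # (video_index, video_group_index) in processing order

total_groups = int(video_grid_thw[:, 0].sum())
for _ in range(total_groups):
    visited.append((video_index, video_group_index))

    video_group_index += 1
    if video_group_index >= video_grid_thw[video_index][0]:
        # All temporal groups of the current video consumed: move on to the next video.
        video_index += 1
        video_group_index = 0

print(visited)  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1)]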

@@ -1175,7 +1179,13 @@ def get_video_features(
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
for t, h, w in video_grid_thw:
repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
video_embeds = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw)
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds
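A small self-contained check of the grid expansion above (dummy tensor only, no vision tower): each per-video [t, h, w] row is expanded into t per-frame [1, h, w] rows for the visual encoder, while split_sizes keeps using the per-video totals.

import torch

# Hypothetical: two videos with 4 and 2 frames, each frame a 6x8 patch grid.
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

temp_frames_hw = []
for t, h, w in video_grid_thw:
    repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
    temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)

print(flattened_video_grid_thw.shape)  # torch.Size([6, 3]) -> one [1, h, w] row per frame
spatial_merge_size = 2  # assumed merge size, for illustration only
print((video_grid_thw.prod(-1) // spatial_merge_size**2).tolist())  # [48, 24] tokens per video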
8 changes: 6 additions & 2 deletions src/transformers/models/glm4v/modular_glm4v.py
@@ -1087,6 +1087,7 @@ def get_rope_index(
device=input_ids.device,
)
image_index, video_index = 0, 0
video_group_index = 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
@@ -1116,7 +1117,6 @@

llm_pos_ids_list = []
video_frame_num = 1

for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0

@@ -1160,7 +1160,11 @@
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

video_index += 1
video_group_index += 1

if video_group_index >= video_grid_thw[video_index][0]:
video_index += 1
video_group_index = 0

video_frame_num += 1

18 changes: 12 additions & 6 deletions src/transformers/models/glm4v/processing_glm4v.py
@@ -167,32 +167,38 @@ def __call__(
video_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
num_frames = len(video_grid_thw)
num_frames = video_grid_thw[video_index][0]
video_structure = ""

if hasattr(timestamps, "tolist"):
timestamps_list = timestamps.tolist()[0]
else:
timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

unique_timestamps = []
for idx in range(0, len(timestamps_list)):
unique_timestamps.append(timestamps_list[idx])

selected_timestamps = unique_timestamps[:num_frames]
while len(selected_timestamps) < num_frames:
selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

for frame_idx in range(num_frames):
timestamp_sec = selected_timestamps[frame_idx]
frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
video_structure += frame_structure

text[i] = text[i].replace(self.video_token, video_structure, 1)
num_image_tokens = (
video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
)
for frame_idx in range(num_frames):
if self.image_token in text[i]:
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

video_index += 1

for frame_idx in range(len(video_grid_thw)):
if self.image_token in text[i]:
num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
text[i] = text[i].replace("<|placeholder|>", self.image_token)

return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
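And a rough standalone mock of the new per-video expansion in the processor (plain strings, invented merge_length and grid; the real method also handles tensor timestamps and batching): each video placeholder becomes one begin/end-of-image block plus timestamp per frame, and each frame then receives prod(thw) // merge_length // t placeholder tokens.

import torch

video_token, image_token = "<video>", "<image>"
merge_length = 4  # hypothetical merge_size ** 2

video_grid_thw = torch.tensor([[2, 4, 4]])  # one video: 2 frames of 4x4 patches
timestamps = [[0.0, 1.0]]
text = ["Describe this clip: <video>"]

video_index = 0
for i in range(len(text)):
    while video_token in text[i]:
        num_frames = int(video_grid_thw[video_index][0])
        selected_timestamps = timestamps[video_index][:num_frames]

        # One begin/end-of-image block plus a timestamp per frame.
        video_structure = "".join(
            f"<|begin_of_image|>{image_token}<|end_of_image|>{selected_timestamps[k]}"
            for k in range(num_frames)
        )
        text[i] = text[i].replace(video_token, video_structure, 1)

        # Tokens per frame: the full t*h*w grid divided by merge_length and by the frame count t.
        num_image_tokens = int(
            video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
        )
        for _ in range(num_frames):
            if image_token in text[i]:
                text[i] = text[i].replace(image_token, "<|placeholder|>" * num_image_tokens, 1)
        video_index += 1
    text[i] = text[i].replace("<|placeholder|>", image_token)

print(text[0].count(image_token))  # 8 -> 2 frames x 4 tokens per frame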
4 changes: 0 additions & 4 deletions src/transformers/models/glm4v/video_processing_glm4v.py
@@ -246,10 +246,6 @@ def _preprocess(
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
pixel_values_videos = torch.cat(processed_videos, dim=0)
video_grid_thw = torch.tensor(processed_grids)
total_frames = video_grid_thw[0][0].item()
h = video_grid_thw[0][1].item()
w = video_grid_thw[0][2].item()
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
Member
I think we would also need to pad timestamps, as otherwise it will fail when a different number of frames is sampled per video. We've been discussing it internally with @zRzRzRzRzRzRzR; not sure, though, whether he has a PR yet.

Contributor Author
Yes, timestamps are not good to return here; can we return them the way qwen2_5vl does?

if isinstance(fps, (int, float)):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
    second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
else:
    raise ValueError(
        f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
    )
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

Member
Hm, not sure this is equivalent to what GLM4V does, because in GLM we want to add per-frame timestamps in the prompt. We discussed this internally and decided that padding/unpadding can work, since the timestamps are used only in internal processing. So we can pad on the right and strip off the pad values in processing.py.
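A rough sketch of the right-pad / strip idea suggested above (names and the pad sentinel are illustrative, not an agreed implementation): the video processor right-pads per-video timestamp lists so videos with different frame counts can be batched, and processing_glm4v.py strips the padding again using the real frame count from video_grid_thw.

import torch

PAD_VALUE = -1.0  # hypothetical sentinel for padded timestamps

def pad_timestamps(timestamps_per_video):
    # Right-pad per-video timestamp lists to the longest video (video-processor side).
    max_len = max(len(ts) for ts in timestamps_per_video)
    return torch.tensor(
        [ts + [PAD_VALUE] * (max_len - len(ts)) for ts in timestamps_per_video]
    )

def unpad_timestamps(padded, video_grid_thw):
    # Strip the padding again using the real frame count per video (processor side).
    return [padded[i, : int(t)].tolist() for i, (t, _, _) in enumerate(video_grid_thw)]

# Two videos with different numbers of sampled frames.
timestamps = [[0.0, 0.5, 1.0, 1.5], [0.0, 1.0]]
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

padded = pad_timestamps(timestamps)              # shape [2, 4], second row right-padded with -1.0
print(unpad_timestamps(padded, video_grid_thw))  # [[0.0, 0.5, 1.0, 1.5], [0.0, 1.0]]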

data = {
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,