fix Glm4v batch videos forward #39172
base: main
Changes from all commits: adc82c8, 807af61, b729471, 5df3828, 454b4a3, 280e506, 2ad2ea2
```diff
@@ -167,32 +167,38 @@ def __call__(
         video_index = 0
         for i in range(len(text)):
             while self.video_token in text[i]:
-                num_frames = len(video_grid_thw)
+                num_frames = video_grid_thw[video_index][0]
                 video_structure = ""

                 if hasattr(timestamps, "tolist"):
                     timestamps_list = timestamps.tolist()[0]
                 else:
                     timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps

                 unique_timestamps = []
                 for idx in range(0, len(timestamps_list)):
                     unique_timestamps.append(timestamps_list[idx])

                 selected_timestamps = unique_timestamps[:num_frames]
                 while len(selected_timestamps) < num_frames:
                     selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)

                 for frame_idx in range(num_frames):
                     timestamp_sec = selected_timestamps[frame_idx]
                     frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                     video_structure += frame_structure

                 text[i] = text[i].replace(self.video_token, video_structure, 1)
+                num_image_tokens = (
+                    video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
+                )
+                for frame_idx in range(num_frames):
+                    if self.image_token in text[i]:
+                        text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)

                 video_index += 1

-            for frame_idx in range(len(video_grid_thw)):
-                if self.image_token in text[i]:
-                    num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
-                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
             text[i] = text[i].replace("<|placeholder|>", self.image_token)

         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
         self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
```

Review comment (on lines +192 to +198): Perfect, this has been itching me since release ❤️ I agree this works when an equal number of frames is sampled per video.
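A minimal sketch of the arithmetic behind the fix, using plain Python lists in place of tensors (the grid values and `merge_size` here are assumed illustrative numbers, not taken from the PR): the frame count is now read per video from `video_grid_thw`, the image-token count per frame reduces to `h * w / merge_length`, and timestamps are truncated or padded to the per-video frame count as in the diff.

```python
import math

# Two videos with different temporal patch counts [t, h, w] (assumed values)
video_grid_thw = [[4, 12, 16], [2, 12, 16]]
merge_size = 2  # assumed; the real value comes from the processor config
merge_length = merge_size**2

# Fixed behavior: num_frames is read per video. The old code used
# len(video_grid_thw), which only matched when the grid had been
# flattened to one row per frame.
frames = [thw[0] for thw in video_grid_thw]

# Image tokens per frame: (t * h * w) // merge_length // t == h * w // merge_length
tokens_per_frame = [math.prod(thw) // merge_length // thw[0] for thw in video_grid_thw]

# Timestamp selection mirrors the diff: truncate to num_frames, then pad
# by repeating the last sampled timestamp if too few were provided
timestamps_list = [0.0, 0.5, 1.0]
selected = timestamps_list[: frames[0]]
while len(selected) < frames[0]:
    selected.append(selected[-1] if selected else 0)

print(frames, tokens_per_frame, selected)
```

With these assumed shapes, both videos expand to 48 image tokens per frame, but video 0 gets 4 frames and video 1 only 2, which is exactly the per-video distinction the old global `len(video_grid_thw)` erased.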
```diff
@@ -246,10 +246,6 @@ def _preprocess(
         processed_grids = reorder_videos(processed_grids, grouped_videos_index)
         pixel_values_videos = torch.cat(processed_videos, dim=0)
         video_grid_thw = torch.tensor(processed_grids)
-        total_frames = video_grid_thw[0][0].item()
-        h = video_grid_thw[0][1].item()
-        w = video_grid_thw[0][2].item()
-        video_grid_thw = [[1, h, w] for _ in range(total_frames)]
         data = {
             "pixel_values_videos": pixel_values_videos,
             "video_grid_thw": video_grid_thw,
```

Review comment: I think we also would need to pad

Reply: Yes, see `transformers/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py`, lines 158 to 166 in df12d87.
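To see why the removed flattening broke batched videos, here is a small sketch comparing the old and new shapes of `video_grid_thw` (plain lists stand in for tensors, and the grid values are assumed for illustration): the old code derived everything from video 0 and silently discarded every other video's frame count.

```python
# Two processed videos with different frame counts (assumed shapes)
processed_grids = [[4, 12, 16], [2, 12, 16]]

# Old (removed) behavior: total_frames, h, w were all taken from video 0,
# then the grid was rebuilt as one [1, h, w] row per frame of video 0 only,
# so the second video's 2 frames were lost
total_frames, h, w = processed_grids[0]
old_video_grid_thw = [[1, h, w] for _ in range(total_frames)]

# New behavior: keep one [t, h, w] row per video, matching the batch
new_video_grid_thw = processed_grids

print(old_video_grid_thw)
print(new_video_grid_thw)
```

Keeping one row per video is also what lets the `__call__` change above index `video_grid_thw[video_index][0]` for a per-video frame count.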
Review comment: oh, probably because this is just copied from Qwen2-VL when running `modular`. To actually fix it, we need to overwrite `get_video_features` in `modular_glm4v.py` instead of inheriting from Qwen2-VL.

Reply: Thanks! I have overwritten this function in the `modular_glm4v.py`.