@@ -6,6 +6,7 @@
 import transformers
 from packaging import version
 
+from swift.utils import get_env_args
 from ..base import Template
 from ..constant import MLLMTemplateType
 from ..register import TemplateMeta, register_template
@@ -307,3 +308,101 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
     ))
 
 register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate))
+
+
+class LLavaOneVision1_5Template(Template):
+    image_token_id = 151655
+    video_token_id = 151656
+    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
+    use_model = True
+    support_padding_free = True
+
+    def init_env_args(self):
+        super().init_env_args()
+        self.bbox_format = get_env_args('QWENVL_BBOX_FORMAT', str, 'legacy')
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        from qwen_vl_utils import fetch_image, fetch_video
+        assert media_type in {'image', 'video'}
+        if media_type == 'image':
+            inputs.images[index] = fetch_image({'image': inputs.images[index]})
+            if self.mode == 'lmdeploy':
+                return ['<|vision_start|>', [-100], '<|vision_end|>']
+            else:
+                return ['<|vision_start|><|image_pad|><|vision_end|>']
+        else:
+            video = inputs.videos[index]
+            video, video_kwargs = fetch_video({'video': video}, return_video_sample_fps=True)
+            inputs.mm_processor_kwargs.setdefault('fps', []).append(video_kwargs)
+            tokens = ['<|vision_start|><|video_pad|><|vision_end|>']
+            if isinstance(video, torch.Tensor):
+                video = video.to(torch.uint8)
+            inputs.videos[index] = video
+            return tokens
+
+    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
+        if self.bbox_format == 'legacy':
+            return [f'<|object_ref_start|>{ref}<|object_ref_end|>']
+        else:
+            return [ref]
+
+    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
+        if self.bbox_format == 'legacy':
+            return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']
+        else:
+            return [str(bbox)]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        processor = self.processor
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        loss_scale = encoded.get('loss_scale', None)
+        for media_type in ['images', 'videos']:
+            mm_data = getattr(inputs, media_type)
+            if mm_data:
+                if media_type == 'images':
+                    media_token = self.image_token_id
+                    media_inputs = processor.image_processor(images=mm_data, return_tensors='pt', do_resize=False)
+                    media_grid_thw = media_inputs['image_grid_thw']
+                else:
+                    kwargs = {}
+                    if hasattr(processor, 'video_processor'):
+                        processor_func = processor.video_processor
+                    else:
+                        processor_func = processor.image_processor
+                        kwargs['images'] = None
+                    media_inputs = processor_func(videos=mm_data, return_tensors='pt', do_resize=False, **kwargs)
+                    media_grid_thw = media_inputs['video_grid_thw']
+                    media_token = self.video_token_id
+                idx_list = findall(input_ids, media_token)
+                merge_length = processor.image_processor.merge_size**2
+
+                def _get_new_tokens(i):
+                    token_len = (media_grid_thw[i].prod() // merge_length)
+                    return [media_token] * token_len
+
+                input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
+                                                                    _get_new_tokens)
+                encoded.update(media_inputs)
+
+        encoded['input_ids'] = input_ids
+        encoded['labels'] = labels
+        encoded['loss_scale'] = loss_scale
+        return encoded
+
+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not self.is_training:
+            return inputs
+        input_ids = inputs['input_ids']
+        base_model = self.get_base_model(model)
+        if hasattr(base_model.model, 'embed_tokens'):
+            inputs_embeds = base_model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
+        inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
+        return {'inputs_embeds': inputs_embeds}
+
+
+register_template(QwenTemplateMeta(MLLMTemplateType.llava_onevision1_5, template_cls=LLavaOneVision1_5Template))
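Note on the placeholder expansion in _encode above: each <|image_pad|> / <|video_pad|> token is expanded into one token per merged vision patch, i.e. grid_thw.prod() // merge_size**2 copies of the corresponding token id. Below is a minimal standalone sketch of that arithmetic; the grid shape (1, 34, 46) and merge_size of 2 are illustrative assumptions, not values taken from this patch.

import torch

# Illustrative values only: one image whose processor output is a 1 x 34 x 46 patch grid
# (temporal x height x width), with a spatial merge_size of 2.
image_grid_thw = torch.tensor([[1, 34, 46]])
merge_size = 2

merge_length = merge_size ** 2                              # 2 x 2 patches merge into one token
token_len = int(image_grid_thw[0].prod()) // merge_length   # 1 * 34 * 46 // 4 = 391

# _encode would replace the single <|image_pad|> placeholder with token_len
# copies of image_token_id (151655), keeping labels and loss_scale aligned.
print(token_len)  # 391

Bounding-box handling follows the same switch as the Qwen-VL templates: QWENVL_BBOX_FORMAT, read via get_env_args with a default of 'legacy', wraps references and boxes in <|object_ref_start|>/<|box_start|> markers; any other value emits the plain text or list form instead.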