gmastrapas committed on
Commit 94a1346 · verified · 1 Parent(s): 4586ed4

Model update

Files changed (2)
  1. README.md +465 -47
  2. processing_jvlm.py +10 -4
README.md CHANGED
@@ -275,92 +275,510 @@ Done ✅
275
  ### Using Transformers 🤗
276
 
277
  ```python
278
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
279
- from qwen_vl_utils import process_vision_info
280
 
281
- # default: Load the model on the available device(s)
282
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
283
- "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
 
284
  )
285
 
286
- # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
287
- # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
288
- # "Qwen/Qwen2.5-VL-3B-Instruct",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # torch_dtype=torch.bfloat16,
290
- # attn_implementation="flash_attention_2",
291
- # device_map="auto",
 
292
  # )
293
 
294
- # default processer
295
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
296
-
297
- # The default range for the number of visual tokens per image in the model is 4-16384.
298
- # You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
299
- # min_pixels = 256*28*28
300
- # max_pixels = 1280*28*28
301
- # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
302
-
303
- messages = [
304
  {
305
- "role": "user",
306
- "content": [
307
  {
308
- "type": "image",
309
- "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
310
  },
311
- {"type": "text", "text": "Describe this image."},
312
  ],
313
  }
314
  ]
315
 
316
- # Preparation for inference
317
- text = processor.apply_chat_template(
318
- messages, tokenize=False, add_generation_prompt=True
319
- )
320
- image_inputs, video_inputs = process_vision_info(messages)
321
  inputs = processor(
322
  text=[text],
323
- images=image_inputs,
324
- videos=video_inputs,
325
- padding=True,
326
- return_tensors="pt",
327
  )
328
- inputs = inputs.to("cuda")
329
 
330
- # Inference: Generation of the output
331
- generated_ids = model.generate(**inputs, max_new_tokens=128)
332
- generated_ids_trimmed = [
333
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
334
- ]
335
- output_text = processor.batch_decode(
336
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
337
  )
338
- print(output_text)
339
  ```
340
 
341
  <details>
342
  <summary>Batch inference</summary>
343
  </details>
344
 
345
  <details>
346
  <summary>Multi-image inference</summary>
347
  </details>
348
 
349
  <details>
350
  <summary>Text-only inference</summary>
351
  </details>
352
 
353
  <details>
354
- <summary>Mixed-batch inference</summary>
355
  </details>
356
 
357
  <details>
358
  <summary>Feature extraction</summary>
359
- </details>
360
 
361
- ### Using vLLM
362
 
363
- Coming soon!
364
 
365
 
366
  ## License
 
275
  ### Using Transformers 🤗
276
 
277
  ```python
278
+ import torch
279
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
280
 
281
+ # Load the processor
282
+ # We don't currently support a fast image processor
283
+ processor = AutoProcessor.from_pretrained(
284
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
285
  )
286
 
287
+ # Load the model on the available device(s)
288
+ model = AutoModelForCausalLM.from_pretrained(
289
+ 'jinaai/jina-vlm-v1',
290
+ device_map='auto',
291
+ trust_remote_code=True
292
+ )
293
+
294
+ # You can specify a different model dtype and/or attention implementation
295
+ # Available attention implementations:
296
+ # 'flash_attention_2', 'sdpa', 'eager'
297
+ # Flash Attention 2 is recommended for improved inference speed and memory efficiency
298
+ # For more details, see https://github.com/Dao-AILab/flash-attention
299
+ # Flash Attention requires an NVIDIA GPU with compute capability >= 8.0 (Ampere or newer)
300
+ # and dtype=torch.bfloat16 or torch.float16
301
+ # SDPA and eager attention are available on both CPU and GPU, with all dtypes
302
+ #
303
+ # model = AutoModelForCausalLM.from_pretrained(
304
+ # 'jinaai/jina-vlm-v1',
305
  # torch_dtype=torch.bfloat16,
306
+ # attn_implementation='flash_attention_2',
307
+ # device_map='auto',
308
+ # trust_remote_code=True
309
  # )
310
 
311
+ image = './assets/the_persistence_of_memory.jpg'
312
+ conversation = [
313
  {
314
+ 'role': 'user',
315
+ 'content': [
316
  {
317
+ 'type': 'image',
318
+ 'image': image,
319
  },
320
+ {'type': 'text', 'text': 'Describe this image.'},
321
  ],
322
  }
323
  ]
324
 
325
+ text = processor.apply_chat_template(conversation, add_generation_prompt=True)
326
  inputs = processor(
327
  text=[text],
328
+ images=[image],
329
+ padding='longest',
330
+ return_tensors='pt',
 
331
  )
332
+ # Configure max_pixels and max_crops when calling the processor
333
+ # max_pixels, if passed, resizes any image that exceeds the max number of pixels while
334
+ # preserving the aspect ratio; fewer pixels means fewer visual tokens
335
+ # max_crops specifies the max number of crops to generate for each image, also
336
+ # reducing the number of visual tokens.
337
+ # inputs = processor(
338
+ # text=[text],
339
+ # images=[image],
340
+ # padding='longest',
341
+ # max_length=1024,
342
+ # max_crops=8,
343
+ # max_pixels=100_000,
344
+ # do_resize=True,
345
+ # return_tensors='pt',
346
+ # )
347
 
348
+ # Move the inputs to the appropriate device and/or dtype
349
+ device = torch.device('cuda')
350
+ dtype = torch.float16
351
+ model_inputs = {}
352
+ for k, v in inputs.items():
353
+ if isinstance(v, torch.Tensor):
354
+ if v.is_floating_point():
355
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
356
+ else:
357
+ model_inputs[k] = v.to(device, non_blocking=True)
358
+ else:
359
+ model_inputs[k] = v
360
+
361
+ # Inference
362
+ output = model.generate(
363
+ **model_inputs,
364
+ generation_config=GenerationConfig(
365
+ max_new_tokens=20, do_sample=False,
366
+ ),
367
+ return_dict_in_generate=True,
368
+ use_model_defaults=True,
369
  )
370
+
371
+ # Decode the output sequences and print the generated text
372
+ # Input prompts will be skipped
373
+ input_sequence_length = inputs.input_ids.shape[-1]
374
+ for idx in range(len(output.sequences)):
375
+ gen_ids = output.sequences[idx][input_sequence_length:]
376
+ response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
377
+ print(response)
378
  ```
379
 
380
  <details>
381
  <summary>Batch inference</summary>
382
+
383
+ ```python
384
+ import torch
385
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
386
+
387
+ processor = AutoProcessor.from_pretrained(
388
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
389
+ )
390
+ model = AutoModelForCausalLM.from_pretrained(
391
+ 'jinaai/jina-vlm-v1',
392
+ device_map='auto',
393
+ torch_dtype=torch.bfloat16,
394
+ attn_implementation='flash_attention_2',
395
+ trust_remote_code=True
396
+ )
397
+ images = [
398
+ 'https://picsum.photos/id/22/4434/3729',
399
+ 'https://picsum.photos/id/49/1280/792'
400
+ ]
401
+ conversations = [
402
+ [
403
+ {
404
+ 'role': 'user',
405
+ 'content': [
406
+ {'type': 'image', 'image': images[0]},
407
+ {'type': 'text', 'text': 'What is the man doing in this image?'},
408
+ ],
409
+ }
410
+ ],
411
+ [
412
+ {
413
+ 'role': 'user',
414
+ 'content': [
415
+ {'type': 'image', 'image': images[1]},
416
+ {'type': 'text', 'text': 'What country\'s flag is in this image?'},
417
+ ],
418
+ }
419
+ ],
420
+
421
+ ]
422
+ texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
423
+ inputs = processor(
424
+ text=texts,
425
+ images=images,
426
+ padding='longest',
427
+ return_tensors='pt',
428
+ )
429
+ device = torch.device('cuda')
430
+ dtype = torch.bfloat16
431
+ model_inputs = {}
432
+ for k, v in inputs.items():
433
+ if isinstance(v, torch.Tensor):
434
+ if v.is_floating_point():
435
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
436
+ else:
437
+ model_inputs[k] = v.to(device, non_blocking=True)
438
+ else:
439
+ model_inputs[k] = v
440
+
441
+ output = model.generate(
442
+ **model_inputs,
443
+ generation_config=GenerationConfig(
444
+ max_new_tokens=20, do_sample=False,
445
+ ),
446
+ return_dict_in_generate=True,
447
+ use_model_defaults=True,
448
+ )
449
+ input_sequence_length = inputs.input_ids.shape[-1]
450
+ for idx in range(len(output.sequences)):
451
+ gen_ids = output.sequences[idx][input_sequence_length:]
452
+ response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
453
+ print(response)
454
+ ```
455
+
456
  </details>
457
 
458
  <details>
459
  <summary>Multi-image inference</summary>
460
+
461
+ ```python
462
+ import torch
463
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
464
+
465
+ processor = AutoProcessor.from_pretrained(
466
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
467
+ )
468
+ model = AutoModelForCausalLM.from_pretrained(
469
+ 'jinaai/jina-vlm-v1',
470
+ device_map='auto',
471
+ torch_dtype=torch.bfloat16,
472
+ attn_implementation='flash_attention_2',
473
+ trust_remote_code=True
474
+ )
475
+ images = [
476
+ 'https://picsum.photos/id/0/5000/3333',
477
+ 'https://picsum.photos/id/2/5000/3333'
478
+ ]
479
+ conversation = [
480
+ {
481
+ 'role': 'user',
482
+ 'content': [
483
+ {'type': 'image', 'image': images[0]},
484
+ {'type': 'image', 'image': images[1]},
485
+ {'type': 'text', 'text': 'What is the difference between these two images?'},
486
+ ],
487
+ }
488
+ ]
489
+ text = processor.apply_chat_template(conversation, add_generation_prompt=True)
490
+ inputs = processor(
491
+ text=[text],
492
+ images=images,
493
+ padding='longest',
494
+ return_tensors='pt',
495
+ )
496
+ device = torch.device('cuda')
497
+ dtype = torch.bfloat16
498
+ model_inputs = {}
499
+ for k, v in inputs.items():
500
+ if isinstance(v, torch.Tensor):
501
+ if v.is_floating_point():
502
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
503
+ else:
504
+ model_inputs[k] = v.to(device, non_blocking=True)
505
+ else:
506
+ model_inputs[k] = v
507
+
508
+ output = model.generate(
509
+ **model_inputs,
510
+ generation_config=GenerationConfig(
511
+ max_new_tokens=20, do_sample=False,
512
+ ),
513
+ return_dict_in_generate=True,
514
+ use_model_defaults=True,
515
+ )
516
+ input_sequence_length = inputs.input_ids.shape[-1]
517
+ for idx in range(len(output.sequences)):
518
+ gen_ids = output.sequences[idx][input_sequence_length:]
519
+ response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
520
+ print(response)
521
+ ```
522
+
523
  </details>
524
 
525
  <details>
526
  <summary>Text-only inference</summary>
527
+
528
+ ```python
529
+ import torch
530
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
531
+
532
+ processor = AutoProcessor.from_pretrained(
533
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
534
+ )
535
+ model = AutoModelForCausalLM.from_pretrained(
536
+ 'jinaai/jina-vlm-v1',
537
+ device_map='auto',
538
+ torch_dtype=torch.bfloat16,
539
+ attn_implementation='flash_attention_2',
540
+ trust_remote_code=True
541
+ )
542
+ conversation = [
543
+ {
544
+ 'role': 'user',
545
+ 'content': [
546
+ {
547
+ 'type': 'text',
548
+ 'text': 'Describe the concept of polymorphism in Computer Science'
549
+ },
550
+ ],
551
+ }
552
+ ]
553
+ text = processor.apply_chat_template(conversation, add_generation_prompt=True)
554
+ inputs = processor(
555
+ text=[text],
556
+ images=None,
557
+ padding='longest',
558
+ return_tensors='pt',
559
+ )
560
+ device = torch.device('cuda')
561
+ dtype = torch.bfloat16
562
+ model_inputs = {}
563
+ for k, v in inputs.items():
564
+ if isinstance(v, torch.Tensor):
565
+ if v.is_floating_point():
566
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
567
+ else:
568
+ model_inputs[k] = v.to(device, non_blocking=True)
569
+ else:
570
+ model_inputs[k] = v
571
+
572
+ output = model.generate(
573
+ **model_inputs,
574
+ generation_config=GenerationConfig(
575
+ max_new_tokens=20, do_sample=False,
576
+ ),
577
+ return_dict_in_generate=True,
578
+ use_model_defaults=True,
579
+ )
580
+ input_sequence_length = inputs.input_ids.shape[-1]
581
+ for idx in range(len(output.sequences)):
582
+ gen_ids = output.sequences[idx][input_sequence_length:]
583
+ response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
584
+ print(response)
585
+ ```
586
+
587
  </details>
588
 
589
  <details>
590
+ <summary>Batch inference with mixed examples</summary>
591
+
592
+ ```python
593
+ import torch
594
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
595
+
596
+ processor = AutoProcessor.from_pretrained(
597
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
598
+ )
599
+ model = AutoModelForCausalLM.from_pretrained(
600
+ 'jinaai/jina-vlm-v1',
601
+ device_map='auto',
602
+ torch_dtype=torch.bfloat16,
603
+ attn_implementation='flash_attention_2',
604
+ trust_remote_code=True
605
+ )
606
+ images = [
607
+ ['https://picsum.photos/id/22/4434/3729'],
608
+ ['https://picsum.photos/id/49/1280/792'],
609
+ [
610
+ 'https://picsum.photos/id/0/5000/3333',
611
+ 'https://picsum.photos/id/2/5000/3333',
612
+ ]
613
+ ]
614
+ conversations = [
615
+ [
616
+ {
617
+ 'role': 'user',
618
+ 'content': [
619
+ {'type': 'image', 'image': images[0][0]},
620
+ {'type': 'text', 'text': 'What is the man doing in this image?'},
621
+ ],
622
+ }
623
+ ],
624
+ [
625
+ {
626
+ 'role': 'user',
627
+ 'content': [
628
+ {'type': 'image', 'image': images[1][0]},
629
+ {'type': 'text', 'text': 'What country\'s flag is in this image?'},
630
+ ],
631
+ }
632
+ ],
633
+ [
634
+ {
635
+ 'role': 'user',
636
+ 'content': [
637
+ {'type': 'image', 'image': images[2][0]},
638
+ {'type': 'image', 'image': images[2][1]},
639
+ {'type': 'text', 'text': 'What is the difference between these two images?'},
640
+ ],
641
+ }
642
+ ],
643
+ [
644
+ {
645
+ 'role': 'user',
646
+ 'content': [
647
+ {
648
+ 'type': 'text',
649
+ 'text': 'Describe the concept of polymorphism in Computer Science'
650
+ },
651
+ ],
652
+ }
653
+ ],
654
+ ]
655
+ texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
656
+ inputs = processor(
657
+ text=texts,
658
+ images=images,
659
+ padding='longest',
660
+ return_tensors='pt',
661
+ )
662
+ device = torch.device('cuda')
663
+ dtype = torch.bfloat16
664
+ model_inputs = {}
665
+ for k, v in inputs.items():
666
+ if isinstance(v, torch.Tensor):
667
+ if v.is_floating_point():
668
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
669
+ else:
670
+ model_inputs[k] = v.to(device, non_blocking=True)
671
+ else:
672
+ model_inputs[k] = v
673
+
674
+ output = model.generate(
675
+ **model_inputs,
676
+ generation_config=GenerationConfig(
677
+ max_new_tokens=20, do_sample=False,
678
+ ),
679
+ return_dict_in_generate=True,
680
+ use_model_defaults=True,
681
+ )
682
+ input_sequence_length = inputs.input_ids.shape[-1]
683
+ for idx in range(len(output.sequences)):
684
+ gen_ids = output.sequences[idx][input_sequence_length:]
685
+ response = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
686
+ print(response)
687
+ ```
688
+
689
  </details>
690
 
691
  <details>
692
  <summary>Feature extraction</summary>
 
693
 
694
+ ```python
695
+ import torch
696
+ from transformers import AutoModel, AutoProcessor
697
 
698
+ processor = AutoProcessor.from_pretrained(
699
+ 'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
700
+ )
701
+ model = AutoModel.from_pretrained(
702
+ 'jinaai/jina-vlm-v1',
703
+ device_map='auto',
704
+ torch_dtype=torch.bfloat16,
705
+ attn_implementation='flash_attention_2',
706
+ trust_remote_code=True
707
+ )
708
+ images = [
709
+ ['https://picsum.photos/id/22/4434/3729'],
710
+ ['https://picsum.photos/id/49/1280/792'],
711
+ [
712
+ 'https://picsum.photos/id/0/5000/3333',
713
+ 'https://picsum.photos/id/2/5000/3333',
714
+ ]
715
+ ]
716
+ conversations = [
717
+ [
718
+ {
719
+ 'role': 'user',
720
+ 'content': [
721
+ {'type': 'image', 'image': images[0][0]},
722
+ {'type': 'text', 'text': 'What is the man doing in this image?'},
723
+ ],
724
+ }
725
+ ],
726
+ [
727
+ {
728
+ 'role': 'user',
729
+ 'content': [
730
+ {'type': 'image', 'image': images[1][0]},
731
+ {'type': 'text', 'text': 'What country\'s flag is in this image?'},
732
+ ],
733
+ }
734
+ ],
735
+ [
736
+ {
737
+ 'role': 'user',
738
+ 'content': [
739
+ {'type': 'image', 'image': images[2][0]},
740
+ {'type': 'image', 'image': images[2][1]},
741
+ {'type': 'text', 'text': 'What is the difference between these two images?'},
742
+ ],
743
+ }
744
+ ],
745
+ [
746
+ {
747
+ 'role': 'user',
748
+ 'content': [
749
+ {
750
+ 'type': 'text',
751
+ 'text': 'Describe the concept of polymorphism in Computer Science'
752
+ },
753
+ ],
754
+ }
755
+ ],
756
+ ]
757
+ texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
758
+ inputs = processor(
759
+ text=texts,
760
+ images=images,
761
+ padding='longest',
762
+ return_tensors='pt',
763
+ )
764
+ device = torch.device('cuda')
765
+ dtype = torch.bfloat16
766
+ model_inputs = {}
767
+ for k, v in inputs.items():
768
+ if isinstance(v, torch.Tensor):
769
+ if v.is_floating_point():
770
+ model_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
771
+ else:
772
+ model_inputs[k] = v.to(device, non_blocking=True)
773
+ else:
774
+ model_inputs[k] = v
775
+
776
+ output = model(**model_inputs)
777
+ hidden_states = output.hidden_states
778
+ last_hidden_states = output.last_hidden_state
779
+ ```
780
+
781
+ </details>
782
 
783
 
784
  ## License
processing_jvlm.py CHANGED
@@ -39,7 +39,7 @@ class JinaVLMTextKwargs(TypedDict, total=False):
39
 
40
 
41
  class JinaVLProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
42
- pass
43
 
44
 
45
  class JinaVLMProcessor(ProcessorMixin):
@@ -259,6 +259,7 @@ class JinaVLMProcessor(ProcessorMixin):
259
  image_tokens: List[np.ndarray],
260
  image_input_idx: List[np.ndarray],
261
  image_padding_mask: List[np.ndarray],
 
262
  add_empty_image_features: bool = False,
263
  ):
264
  """Interleave images and text tokens into multi-modal features for the model."""
@@ -282,8 +283,9 @@ class JinaVLMProcessor(ProcessorMixin):
282
  data = {
283
  'input_ids': input_ids,
284
  'position_ids': position_ids,
285
- 'labels': target_tokens,
286
  }
287
  if add_empty_image_features:
288
  # Add size-zero image features, this can be useful to make sure all
289
  # devices get an image input when the image ViT is FSDP wrapped
@@ -367,14 +369,16 @@ class JinaVLMProcessor(ProcessorMixin):
367
  image_input_idx < 0, image_input_idx, image_input_idx + 1
368
  )
369
  position_ids = np.arange(len(input_ids), dtype=np.int64)
370
- return {
371
  'input_ids': input_ids,
372
  'position_ids': position_ids,
373
  'images': images,
374
  'image_input_idx': image_input_idx,
375
  'image_masks': image_masks,
376
- 'labels': target_tokens,
377
  }
378
 
379
  def __call__(
380
  self,
@@ -425,6 +429,7 @@ class JinaVLMProcessor(ProcessorMixin):
425
  raise ValueError('Processor requires text input.')
426
 
427
  return_tensors = kwargs.pop('return_tensors', None)
 
428
  padding = kwargs.pop('padding', PaddingStrategy.LONGEST)
429
  padding_side = kwargs.pop('padding_side', 'left')
430
  max_length = kwargs.pop('max_length', None)
@@ -498,6 +503,7 @@ class JinaVLMProcessor(ProcessorMixin):
498
  image_input_idx,
499
  image_padding_mask if image_padding_mask is not None else [],
500
  add_empty_image_features=(batch_size > 1),
 
501
  )
502
  for k, v in output.items():
503
  outputs[k].append(v)
 
39
 
40
 
41
  class JinaVLProcessingKwargs(JinaVLMTextKwargs, JinaVLMImagesKwargs, CommonKwargs):
42
+ return_labels: Optional[bool]
43
 
44
 
45
  class JinaVLMProcessor(ProcessorMixin):
 
259
  image_tokens: List[np.ndarray],
260
  image_input_idx: List[np.ndarray],
261
  image_padding_mask: List[np.ndarray],
262
+ return_labels: bool = False,
263
  add_empty_image_features: bool = False,
264
  ):
265
  """Interleave images and text tokens into multi-modal features for the model."""
 
283
  data = {
284
  'input_ids': input_ids,
285
  'position_ids': position_ids,
 
286
  }
287
+ if return_labels:
288
+ data['labels'] = target_tokens
289
  if add_empty_image_features:
290
  # Add size-zero image features, this can be useful to make sure all
291
  # devices get an image input when the image ViT is FSDP wrapped
 
369
  image_input_idx < 0, image_input_idx, image_input_idx + 1
370
  )
371
  position_ids = np.arange(len(input_ids), dtype=np.int64)
372
+ data = {
373
  'input_ids': input_ids,
374
  'position_ids': position_ids,
375
  'images': images,
376
  'image_input_idx': image_input_idx,
377
  'image_masks': image_masks,
 
378
  }
379
+ if return_labels:
380
+ data['labels'] = target_tokens
381
+ return data
382
 
383
  def __call__(
384
  self,
 
429
  raise ValueError('Processor requires text input.')
430
 
431
  return_tensors = kwargs.pop('return_tensors', None)
432
+ return_labels = kwargs.pop('return_labels', False)
433
  padding = kwargs.pop('padding', PaddingStrategy.LONGEST)
434
  padding_side = kwargs.pop('padding_side', 'left')
435
  max_length = kwargs.pop('max_length', None)
 
503
  image_input_idx,
504
  image_padding_mask if image_padding_mask is not None else [],
505
  add_empty_image_features=(batch_size > 1),
506
+ return_labels=return_labels,
507
  )
508
  for k, v in output.items():
509
  outputs[k].append(v)
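
The `return_labels` change above makes `labels` opt-in rather than always present in the processor output. Below is a minimal caller-side sketch of the new behavior, assuming the updated `processing_jvlm.py` is what `trust_remote_code=True` loads for `jinaai/jina-vlm-v1` (checkpoint name and prompt are reused from the README examples above).

```python
from transformers import AutoProcessor

# Sketch only: relies on the return_labels kwarg introduced in this commit.
processor = AutoProcessor.from_pretrained(
    'jinaai/jina-vlm-v1', use_fast=False, trust_remote_code=True
)

conversation = [
    {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'Describe the concept of polymorphism in Computer Science'},
        ],
    }
]
text = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Default after this commit: no 'labels' key in the processor output
inputs = processor(text=[text], images=None, padding='longest', return_tensors='pt')
print('labels' in inputs)  # expected: False

# Opt in explicitly, e.g. when building supervised fine-tuning batches
inputs = processor(
    text=[text],
    images=None,
    padding='longest',
    return_labels=True,
    return_tensors='pt',
)
print('labels' in inputs)  # expected: True
```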