feat-rename-vector-type-0622

#21
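This PR replaces the string-valued vector_type parameter ("single_vector" / "multi_vector", previously validated against the now-removed VECTOR_TYPES constant) with a boolean return_multivector flag, renames encode_texts/encode_images to encode_text/encode_image, and lets both methods accept a single text or image in addition to a list, returning a single embedding in that case.

A minimal before/after sketch of the call-site change, assuming the model is loaded from the Hub with trust_remote_code=True; the checkpoint id below is illustrative, not part of this diff:

from transformers import AutoModel

# Hypothetical checkpoint id; any repo shipping this modeling file works.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4", trust_remote_code=True
)

# Before this PR: list-only input, string-valued vector_type.
# embs = model.encode_texts(["hello world"], vector_type="multi_vector")

# After this PR: boolean flag, and a bare string is accepted.
emb = model.encode_text("hello world", return_multivector=True)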
Files changed (1)
  1. modeling_jina_embeddings_v4.py +38 -35
modeling_jina_embeddings_v4.py CHANGED
@@ -31,7 +31,6 @@ class PromptType(str, Enum):
 
 
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
-VECTOR_TYPES = ["single_vector", "multi_vector"]
 
 
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
@@ -284,8 +283,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             attention_mask (torch.Tensor): The attention mask tensor.
         Returns:
             JinaEmbeddingsV4ModelOutput:
-                single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
-                multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
+                vlm_last_hidden_states (torch.Tensor, optional): Last hidden states of the VLM.
+                single_vec_emb (torch.Tensor, optional): Single-vector embeddings.
+                multi_vec_emb (torch.Tensor, optional): Multi-vector embeddings.
         """
         # Forward pass through the VLM
         hidden_states = self.get_last_hidden_states(
@@ -320,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: str = "single_vector",
+        return_multivector: bool = False,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -340,7 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-                if vector_type == "single_vector":
+                if not return_multivector:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
@@ -357,7 +357,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[str] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -374,14 +373,6 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or "single_vector"
-        if vector_type not in VECTOR_TYPES:
-            raise ValueError(
-                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
-            )
-        else:
-            encode_kwargs["vector_type"] = vector_type
-
         truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
             raise ValueError(
@@ -407,36 +398,34 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         return task
 
-    def encode_texts(
+    def encode_text(
         self,
-        texts: List[str],
+        texts: Union[str, List[str]],
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
         Encodes a list of texts into embeddings.
 
         Args:
-            texts: List of text strings to encode
+            texts: text or list of text strings to encode
            max_length: Maximum token length for text processing
            batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
            return_numpy: Whether to return numpy arrays instead of torch tensors
            truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
            prompt_name: Type of text being encoded ('query' or 'passage')
 
        Returns:
-            List of text embeddings as tensors or numpy arrays
+            List of text embeddings as tensors or numpy arrays when encoding multiple texts, or single text embedding as tensor when encoding a single text
        """
        prompt_name = prompt_name or "query"
-        encode_kwargs = self._validate_encoding_params(
-            vector_type, truncate_dim, prompt_name
-        )
+        encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim, prompt_name=prompt_name)
 
        task = self._validate_task(task)
 
@@ -446,17 +435,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             prefix=encode_kwargs.pop("prefix"),
         )
 
+        return_list = isinstance(texts, list)
+
+        if isinstance(texts, str):
+            texts = [texts]
+
         embeddings = self._process_batches(
             data=texts,
             processor_fn=processor_fn,
             desc="Encoding texts...",
             task_label=task,
+            return_multivector=return_multivector,
             return_numpy=return_numpy,
             batch_size=batch_size,
             **encode_kwargs,
         )
 
-        return embeddings
+        return embeddings if return_list else embeddings[0]
 
     def _load_images_if_needed(
         self, images: List[Union[str, Image.Image]]
@@ -472,37 +467,44 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 loaded_images.append(image)
         return loaded_images
 
-    def encode_images(
+    def encode_image(
         self,
-        images: List[Union[str, Image.Image]],
+        images: Union[str, Image.Image, List[Union[str, Image.Image]]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        return_multivector: bool = False,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
-    ) -> List[torch.Tensor]:
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
         """
-        Encodes a list of images into embeddings.
+        Encodes a list of images or a single image into embedding(s).
 
         Args:
-            images: List of PIL images, URLs, or local file paths to encode
+            images: image(s) to encode, can be PIL Image(s), URL(s), or local file path(s)
            batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
+            return_multivector: Whether to return multi-vector embeddings instead of single-vector embeddings
            return_numpy: Whether to return numpy arrays instead of torch tensors
            truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
            max_pixels: Maximum number of pixels to process per image
 
        Returns:
-            List of image embeddings as tensors or numpy arrays
+            List of image embeddings as tensors or numpy arrays when encoding multiple images, or single image embedding as tensor when encoding a single image
        """
        if max_pixels:
            default_max_pixels = self.processor.image_processor.max_pixels
            self.processor.image_processor.max_pixels = (
                max_pixels  # change during encoding
            )
-        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
+        encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
        task = self._validate_task(task)
+
+        return_list = isinstance(images, list)
+
+        # Convert single image to list
+        if isinstance(images, (str, Image.Image)):
+            images = [images]
+
        images = self._load_images_if_needed(images)
        embeddings = self._process_batches(
            data=images,
@@ -510,6 +512,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             desc="Encoding images...",
             task_label=task,
             batch_size=batch_size,
+            return_multivector=return_multivector,
             return_numpy=return_numpy,
             **encode_kwargs,
         )
@@ -517,7 +520,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         if max_pixels:
             self.processor.image_processor.max_pixels = default_max_pixels
 
-        return embeddings
+        return embeddings if return_list else embeddings[0]
 
     @classmethod
     def from_pretrained(
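For review context, a quick sketch of the return-type behavior this diff introduces, reusing the hypothetical model object from the example above; per the _process_batches branch shown here, truncate_dim is applied only on the single-vector path:

# A bare string goes through the texts = [texts] shim and comes
# back as a single tensor (embeddings[0]).
single = model.encode_text("a lone query")

# A list comes back as a list, optionally truncated to one of the
# Matryoshka dims named in the docstrings (128, 256, 512, or 1024).
batch = model.encode_text(["doc one", "doc two"], truncate_dim=128)
assert len(batch) == 2

# Multi-vector output keeps per-token embeddings and skips truncation.
multi = model.encode_image("photo.jpg", return_multivector=True)  # hypothetical local path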