refactor-task (#18)

refactor: rename task type to task (fb3bde88dfd9f5d35368582c3840fd30439cbbf7)

- README.md +8 -8
- custom_st.py +6 -6
- modules.json +1 -1
README.md
CHANGED

@@ -21546,7 +21546,7 @@ Additionally, it features 5 [LoRA](https://arxiv.org/abs/2106.09685) adapters to
 
 ### Key Features:
 - **Extended Sequence Length:** Supports up to 8192 tokens with RoPE.
-- **Task-Specific Embedding:** Customize embeddings through the `task_type` argument with the following options:
+- **Task-Specific Embedding:** Customize embeddings through the `task` argument with the following options:
     - `retrieval.query`: Used for query embeddings in asymmetric retrieval tasks
     - `retrieval.passage`: Used for passage embeddings in asymmetric retrieval tasks
     - `separation`: Used for embeddings in clustering and re-ranking applications
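The two retrieval options in this list exist for asymmetric search, where queries and documents are embedded with different adapters. A minimal sketch against the post-rename API, assuming the custom `encode` method that the Transformers example later in this README relies on (the passage text is invented for illustration):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Queries and passages use different LoRA adapters in asymmetric retrieval.
query_embs = model.encode(
    ["What is the weather like in Berlin today?"], task="retrieval.query"
)
passage_embs = model.encode(
    ["Berlin will be sunny with highs of 25 °C."], task="retrieval.passage"
)
print(query_embs @ passage_embs.T)  # query-passage similarity scores
```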
@@ -21605,7 +21605,7 @@ model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code
 encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
 
 with torch.no_grad():
-    model_output = model(**encoded_input, task_type='retrieval.query')
+    model_output = model(**encoded_input, task='retrieval.query')
 
 embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
 embeddings = F.normalize(embeddings, p=2, dim=1)
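This hunk calls a `mean_pooling` helper defined earlier in the full README, outside the diff context. For reference, the standard masked-average implementation looks like the sketch below; the README's exact version may differ in detail:

```python
import torch

def mean_pooling(model_output, attention_mask):
    # Average token embeddings over the sequence, ignoring padding positions.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
```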
@@ -21643,10 +21643,10 @@ texts = [
     "Folge dem weißen Kaninchen.",  # German
 ]
 
-# When calling the `encode` function, you can choose a `task_type` based on the use case:
+# When calling the `encode` function, you can choose a `task` based on the use case:
 # 'retrieval.query', 'retrieval.passage', 'separation', 'classification', 'text-matching'
-# Alternatively, you can choose not to pass a `task_type`, and no specific LoRA adapter will be used.
+# Alternatively, you can choose not to pass a `task`, and no specific LoRA adapter will be used.
-embeddings = model.encode(texts, task_type="text-matching")
+embeddings = model.encode(texts, task="text-matching")
 
 # Compute similarities
 print(embeddings[0] @ embeddings[1].T)
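`embeddings[0] @ embeddings[1].T` scores only the first two texts. Continuing the same snippet, the full pairwise matrix is one product away; assuming the returned embeddings are L2-normalized, as the earlier Transformers example does explicitly, these dot products are cosine similarities:

```python
import numpy as np

# Continues the snippet above: pairwise similarities for all texts at once.
similarities = embeddings @ embeddings.T
print(np.round(similarities, 3))
```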
@@ -21680,11 +21680,11 @@ from sentence_transformers import SentenceTransformer
 
 model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
 
-task_type = "retrieval.query"
+task = "retrieval.query"
 embeddings = model.encode(
     ["What is the weather like in Berlin today?"],
-    task_type=task_type,
+    task=task,
-    prompt_name=task_type,
+    prompt_name=task,
 )
 ```
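Note the two distinct knobs in this hunk: `task` is routed to `custom_st.Transformer.forward` (see the diff below) and selects the LoRA adapter, while `prompt_name` selects an instruction prompt that Sentence Transformers prepends to the input text; the model presumably registers prompts under the same task names, which is why one variable can feed both. A sketch of the two usage modes:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Adapter plus matching instruction prompt, as in the hunk above.
query_embs = model.encode(
    ["What is the weather like in Berlin today?"],
    task="retrieval.query",
    prompt_name="retrieval.query",
)

# Without `task`, no specific LoRA adapter is applied (per the README comment).
plain_embs = model.encode(["What is the weather like in Berlin today?"])
```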
custom_st.py
CHANGED

@@ -91,19 +91,19 @@ class Transformer(nn.Module):
         self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
     def forward(
-        self, features: Dict[str, torch.Tensor], task_type: Optional[str] = None
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
     ) -> Dict[str, torch.Tensor]:
         """Returns token_embeddings, cls_token"""
-        if task_type and task_type not in self._lora_adaptations:
+        if task and task not in self._lora_adaptations:
             raise ValueError(
-                f"Unsupported task '{task_type}'. "
+                f"Unsupported task '{task}'. "
                 f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
-                f"Alternatively, don't pass the `task_type` argument to disable LoRA."
+                f"Alternatively, don't pass the `task` argument to disable LoRA."
             )
 
         adapter_mask = None
-        if task_type:
+        if task:
-            task_id = self._adaptation_map[task_type]
+            task_id = self._adaptation_map[task]
             num_examples = features['input_ids'].size(0)
             adapter_mask = torch.full(
                 (num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
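For readers following the control flow: `_adaptation_map` is presumably built from the config's `lora_adaptations`, and the mask carries one adapter id per batch example. A self-contained sketch of that dispatch, with the task names taken from the README (names and shapes are illustrative, not the module's exact internals):

```python
import torch

# Task names from the README; each maps to an integer LoRA adapter id.
lora_adaptations = [
    "retrieval.query", "retrieval.passage", "separation",
    "classification", "text-matching",
]
adaptation_map = {name: i for i, name in enumerate(lora_adaptations)}

def build_adapter_mask(task, input_ids):
    # One adapter id per example in the batch; None disables LoRA entirely.
    if task is None:
        return None
    task_id = adaptation_map[task]
    return torch.full(
        (input_ids.size(0),), task_id, dtype=torch.int32, device=input_ids.device
    )

dummy_ids = torch.zeros(4, 8, dtype=torch.long)
print(build_adapter_mask("retrieval.query", dummy_ids))  # tensor([0, 0, 0, 0], dtype=torch.int32)
```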
modules.json
CHANGED

@@ -4,7 +4,7 @@
     "name": "0",
     "path": "",
     "type": "custom_st.Transformer",
-    "kwargs": ["task_type"]
+    "kwargs": ["task"]
   },
   {
     "idx": 1,