import random
import logging

from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from sentence_transformers.util import cos_sim

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
					
						
# Initialize a static embedding module (a 1024-dim token embedding table whose mean is the
# text embedding) on top of the BEE-spoke-data/wordpiece-tokenizer-32k-en_code-msp tokenizer,
# and wrap it in a SentenceTransformer together with (optional) model card data
static_embedding = StaticEmbedding(
    AutoTokenizer.from_pretrained("BEE-spoke-data/wordpiece-tokenizer-32k-en_code-msp"),
    embedding_dim=1024,
)
model = SentenceTransformer(
    modules=[static_embedding],
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Static Embeddings with BEE-spoke-data/wordpiece-tokenizer-32k-en_code-msp tokenizer finetuned on GooAQ pairs",
    ),
)
					
						
# Load the GooAQ (question, answer) pairs, add an "id" column, and hold out 10k pairs for evaluation
dataset = load_dataset("sentence-transformers/gooaq", split="train")
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]
					
						
# Train with in-batch negatives (MultipleNegativesRankingLoss), wrapped in MatryoshkaLoss so the
# embeddings remain useful when truncated to 32, 64, ..., 1024 dimensions
loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])
					
						
# Specify the training arguments
run_name = "static-BEE-spoke-data-tokenizer-v2-gooaq"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=2048,
    per_device_eval_batch_size=2048,
    learning_rate=2e-1,  # static embedding models train well with a far higher learning rate than transformer models
    warmup_ratio=0.1,
    fp16=False,  # set to True if your GPU does not support bf16
    bf16=True,  # set to False if your GPU does not support bf16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicates within a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=2,
    logging_steps=100,
    logging_first_step=True,
    run_name=run_name,  # used in W&B if `wandb` is installed
)
					
						
# Set up an information-retrieval evaluator: the held-out questions are queries, and the corpus
# is their gold answers plus 20k randomly sampled answers as distractors. One evaluator per
# Matryoshka dimension, run sequentially; evaluate once before training as a baseline.
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    | {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
evaluators = []
for dim in loss.matryoshka_dims:
    evaluators.append(
        InformationRetrievalEvaluator(
            corpus=corpus,
            queries=queries,
            relevant_docs=relevant_docs,
            show_progress_bar=True,
            name=f"gooaq-{dim}-dev",
            truncate_dim=dim,
            score_functions={"cosine": cos_sim},
        )
    )
dev_evaluator = SequentialEvaluator(evaluators)
dev_evaluator(model)
					
						
# Create the trainer and start training; the "id" column is only needed by the evaluator, so drop it here
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
					
						
# Re-run the evaluator on the trained model
dev_evaluator(model)
					
						
# Save the trained model locally
model.save_pretrained(f"models/{run_name}/final")
					
						
# (Optional) push the trained model to the Hugging Face Hub
model.push_to_hub(run_name, private=True)
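# Usage sketch (illustrative, not part of the training run): reload the saved model and embed a
# couple of placeholder texts, optionally truncated to one of the Matryoshka dimensions trained
# above. Variable names and example texts here are arbitrary.
loaded = SentenceTransformer(f"models/{run_name}/final", truncate_dim=256)
embeddings = loaded.encode(["how do i train a static embedding model?", "an example answer"])
print(embeddings.shape)  # expected: (2, 256)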