import random
import logging

from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk, or download them via the datasets
    library and save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory
    before training; rerunning the script then loads them memory-mapped from disk.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
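        # Each parallel-sentences dataset pairs an English sentence with its translation,
        # so every row acts as an (anchor, positive) pair for the ranking loss used below.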
						|  | print("Loading wikititles dataset...") | 
					
						
						|  | wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train") | 
					
						
						|  | wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12) | 
					
						
						|  | wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"] | 
					
						
						|  | wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"] | 
					
						
						|  | print("Loaded wikititles dataset.") | 
					
						
						|  |  | 
					
						
						|  | print("Loading tatoeba dataset...") | 
					
						
						|  | tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train") | 
					
						
						|  | tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12) | 
					
						
						|  | tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"] | 
					
						
						|  | tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"] | 
					
						
						|  | print("Loaded tatoeba dataset.") | 
					
						
						|  |  | 
					
						
						|  | print("Loading talks dataset...") | 
					
						
						|  | talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train") | 
					
						
						|  | talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12) | 
					
						
						|  | talks_train_dataset: Dataset = talks_dataset_dict["train"] | 
					
						
						|  | talks_eval_dataset: Dataset = talks_dataset_dict["test"] | 
					
						
						|  | print("Loaded talks dataset.") | 
        print("Loading europarl dataset...")
        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
        print("Loaded europarl dataset.")

        print("Loading global voices dataset...")
        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
        print("Loaded global voices dataset.")

        print("Loading jw300 dataset...")
        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
        print("Loaded jw300 dataset.")

        print("Loading muse dataset...")
        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
        muse_train_dataset: Dataset = muse_dataset_dict["train"]
        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
        print("Loaded muse dataset.")

        print("Loading wikimatrix dataset...")
        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
        print("Loaded wikimatrix dataset.")

        print("Loading opensubtitles dataset...")
        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
        print("Loaded opensubtitles dataset.")

        print("Loading stackexchange dataset...")
        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
        print("Loaded stackexchange dataset.")

        print("Loading quora dataset...")
        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
        quora_train_dataset: Dataset = quora_dataset_dict["train"]
        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
        print("Loaded quora dataset.")

        print("Loading wikianswers duplicates dataset...")
        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
        print("Loaded wikianswers duplicates dataset.")

        print("Loading all nli dataset...")
        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded all nli dataset.")

        print("Loading simple wiki dataset...")
        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
        print("Loaded simple wiki dataset.")

        print("Loading altlex dataset...")
        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
        print("Loaded altlex dataset.")

        print("Loading flickr30k captions dataset...")
        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
        print("Loaded flickr30k captions dataset.")

        print("Loading coco captions dataset...")
        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
        print("Loaded coco captions dataset.")

        print("Loading nli for simcse dataset...")
        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
        print("Loaded nli for simcse dataset.")

        print("Loading negation dataset...")
        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
        negation_train_dataset: Dataset = negation_dataset_dict["train"]
        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
        print("Loaded negation dataset.")
        train_dataset = DatasetDict({
            "wikititles": wikititles_train_dataset,
            "tatoeba": tatoeba_train_dataset,
            "talks": talks_train_dataset,
            "europarl": europarl_train_dataset,
            "global_voices": global_voices_train_dataset,
            "jw300": jw300_train_dataset,
            "muse": muse_train_dataset,
            "wikimatrix": wikimatrix_train_dataset,
            "opensubtitles": opensubtitles_train_dataset,
            "stackexchange": stackexchange_train_dataset,
            "quora": quora_train_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
            "all_nli": all_nli_train_dataset,
            "simple_wiki": simple_wiki_train_dataset,
            "altlex": altlex_train_dataset,
            "flickr30k_captions": flickr30k_captions_train_dataset,
            "coco_captions": coco_captions_train_dataset,
            "nli_for_simcse": nli_for_simcse_train_dataset,
            "negation": negation_train_dataset,
        })
        eval_dataset = DatasetDict({
            "wikititles": wikititles_eval_dataset,
            "tatoeba": tatoeba_eval_dataset,
            "talks": talks_eval_dataset,
            "europarl": europarl_eval_dataset,
            "global_voices": global_voices_eval_dataset,
            "jw300": jw300_eval_dataset,
            "muse": muse_eval_dataset,
            "wikimatrix": wikimatrix_eval_dataset,
            "opensubtitles": opensubtitles_eval_dataset,
            "stackexchange": stackexchange_eval_dataset,
            "quora": quora_eval_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
            "all_nli": all_nli_eval_dataset,
            "simple_wiki": simple_wiki_eval_dataset,
            "altlex": altlex_eval_dataset,
            "flickr30k_captions": flickr30k_captions_eval_dataset,
            "coco_captions": coco_captions_eval_dataset,
            "nli_for_simcse": nli_for_simcse_eval_dataset,
            "negation": negation_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
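
        # The train_test_split calls above keep much of the data in memory. Quit here and
        # rerun the script so the saved datasets are loaded memory-mapped from disk instead.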
        quit()


def main():
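    # A StaticEmbedding module is a tokenizer plus a trainable embedding matrix: a sentence
    # embedding is simply the mean of its token embeddings, with no attention layers, which
    # is what makes static models so fast at inference time.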
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"),
        embedding_dim=1024,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

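    # MultipleNegativesRankingLoss treats each row as an (anchor, positive, ...) sample and
    # uses the other in-batch texts as negatives; wrapping it in MatryoshkaLoss additionally
    # trains the 32/64/128/256/512-dim truncations of each 1024-dim embedding.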
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

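    # Note: learning_rate=2e-1 is enormous by transformer-finetuning standards, but a bare
    # embedding matrix with a large batch size trains stably at much higher learning rates.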
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        output_dir=f"models/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,
        bf16=True,
        # Avoid duplicate texts within a batch, which would act as false negatives for MNRL
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        # Sample from each dataset in proportion to its size
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    model.save_pretrained(f"models/{run_name}/final")
    model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()
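
# A minimal usage sketch once training has finished (paths assume the defaults above):
#
#   from sentence_transformers import SentenceTransformer
#   model = SentenceTransformer("models/static-similarity-mrl-multilingual-v1/final")
#   embeddings = model.encode(["The weather is lovely today.", "Das Wetter ist heute schön."])
#   print(model.similarity(embeddings, embeddings))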