import typing as T

import selfies as sf
from tokenizers import Encoding, Tokenizer, models, processors
from tokenizers.implementations import BaseTokenizer, ByteLevelBPETokenizer

import massspecgym.utils as utils
from massspecgym.definitions import PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN


class SpecialTokensBaseTokenizer(BaseTokenizer):
    def __init__(
        self,
        tokenizer: Tokenizer,
        max_len: T.Optional[int] = None,
    ):
        """Initialize the base tokenizer with special tokens, padding, and truncation."""
        super().__init__(tokenizer)

        # Save essential attributes
        self.pad_token = PAD_TOKEN
        self.sos_token = SOS_TOKEN
        self.eos_token = EOS_TOKEN
        self.unk_token = UNK_TOKEN
        self.max_length = max_len

        # Register special tokens with the underlying tokenizer
        self.add_special_tokens(
            [self.pad_token, self.sos_token, self.eos_token, self.unk_token]
        )

        # Get token IDs
        self.pad_token_id = self.token_to_id(self.pad_token)
        self.sos_token_id = self.token_to_id(self.sos_token)
        self.eos_token_id = self.token_to_id(self.eos_token)
        self.unk_token_id = self.token_to_id(self.unk_token)

        # Enable padding (pads to `max_len` if given, otherwise to the longest sequence in a batch)
        self.enable_padding(
            direction="right",
            pad_token=self.pad_token,
            pad_id=self.pad_token_id,
            length=max_len,
        )

        # Enable truncation (only when a maximum length is specified)
        if max_len is not None:
            self.enable_truncation(max_len)

        # Set post-processing to add SOS and EOS tokens
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{self.sos_token} $A {self.eos_token}",
            pair=f"{self.sos_token} $A {self.eos_token} {self.sos_token} $B {self.eos_token}",
            special_tokens=[
                (self.sos_token, self.sos_token_id),
                (self.eos_token, self.eos_token_id),
            ],
        )


class SelfiesTokenizer(SpecialTokensBaseTokenizer):
    def __init__(
        self, selfies_train: T.Optional[T.Union[str, T.List[str]]] = None, **kwargs
    ):
        """
        Initialize the SELFIES tokenizer with optional training data to build a vocabulary.

        Args:
            selfies_train (str or list of str): Either a list of SELFIES strings to build the
                vocabulary from, or the string 'semantic_robust_alphabet', indicating that the
                alphabet from `selfies.get_semantic_robust_alphabet()` should be used. If None,
                the MassSpecGym training molecules will be used.
        """
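        # The vocabulary is built from one of three sources:
        #   - the semantic-robust alphabet bundled with the `selfies` package,
        #   - a user-provided list of SELFIES strings, or
        #   - SELFIES derived from the MassSpecGym training molecules (the default).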
""" if selfies_train == 'semantic_robust_alphabet': alphabet = list(sorted(sf.get_semantic_robust_alphabet())) else: if not selfies_train: selfies_train = utils.load_train_mols() selfies = [sf.encoder(s, strict=False) for s in selfies_train] else: selfies = selfies_train alphabet = list(sorted(sf.get_alphabet_from_selfies(selfies))) vocab = {symbol: i for i, symbol in enumerate(alphabet)} vocab[UNK_TOKEN] = len(vocab) tokenizer = Tokenizer(models.WordLevel(vocab=vocab, unk_token=UNK_TOKEN)) super().__init__(tokenizer, **kwargs) def encode(self, text: str, add_special_tokens: bool = True) -> Tokenizer: """Encodes a SMILES string into a list of SELFIES token IDs.""" selfies_string = sf.encoder(text, strict=False) selfies_tokens = list(sf.split_selfies(selfies_string)) return super().encode( selfies_tokens, is_pretokenized=True, add_special_tokens=add_special_tokens ) def decode(self, token_ids: T.List[int], skip_special_tokens: bool = True) -> str: """Decodes a list of SELFIES token IDs back into a SMILES string.""" selfies_string = super().decode( token_ids, skip_special_tokens=skip_special_tokens ) selfies_string = self._decode_wordlevel_str_to_selfies(selfies_string) return sf.decoder(selfies_string) def encode_batch( self, texts: T.List[str], add_special_tokens: bool = True ) -> T.List[Tokenizer]: """Encodes a batch of SMILES strings into a list of SELFIES token IDs.""" selfies_strings = [ list(sf.split_selfies(sf.encoder(text, strict=False))) for text in texts ] return super().encode_batch( selfies_strings, is_pretokenized=True, add_special_tokens=add_special_tokens ) def decode_batch( self, token_ids_batch: T.List[T.List[int]], skip_special_tokens: bool = True ) -> T.List[str]: """Decodes a batch of SELFIES token IDs back into SMILES strings.""" selfies_strings = super().decode_batch( token_ids_batch, skip_special_tokens=skip_special_tokens ) return [ sf.decoder( self._decode_wordlevel_str_to_selfies( selfies_string ) ) for selfies_string in selfies_strings ] def _decode_wordlevel_str_to_selfies(self, text: str) -> str: """Converts a WordLevel string back to a SELFIES string.""" return text.replace(" ", "") class SmilesBPETokenizer(SpecialTokensBaseTokenizer): def __init__(self, smiles_pth: T.Optional[str] = None, **kwargs): """ Initialize the BPE tokenizer for SMILES strings, with optional training data. Args: smiles_pth (str): Path to a file containing SMILES strings to train the tokenizer on. If None, the MassSpecGym training molecules will be used. """ tokenizer = ByteLevelBPETokenizer() if smiles_pth: tokenizer.train(smiles_pth) else: smiles = utils.load_unlabeled_mols("smiles").tolist() smiles += utils.load_train_mols().tolist() print(f"Training tokenizer on {len(smiles)} SMILES strings.") tokenizer.train_from_iterator(smiles) super().__init__(tokenizer, **kwargs)