import random
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from enumerator import SmilesEnumerator
from utils import split

PAD = 0
MAX_LEN = 220


class Randomizer(object):

    def __init__(self):
        self.sme = SmilesEnumerator()

    def __call__(self, sm):
        sm_r = self.sme.randomize_smiles(sm)  # Random transform
        if sm_r is None:
            sm_spaced = split(sm)  # Spacing
        else:
            sm_spaced = split(sm_r)  # Spacing
        sm_split = sm_spaced.split()
        if len(sm_split) <= MAX_LEN - 2:  # Leave room for <sos> and <eos>
            return sm_split  # List of tokens
        else:
            return split(sm).split()  # Fall back to the original SMILES

    def random_transform(self, sm):
        '''
        function: Random transformation for SMILES. It may take some time.
        input: A SMILES
        output: A randomized SMILES
        '''
        return self.sme.randomize_smiles(sm)


class Seq2seqDataset(Dataset):

    def __init__(self, smiles, vocab, seq_len=220, transform=Randomizer()):
        self.smiles = smiles
        self.vocab = vocab
        self.seq_len = seq_len
        self.transform = transform

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, item):
        sm = self.smiles[item]
        sm = self.transform(sm)  # List of tokens
        content = [self.vocab.stoi.get(token, self.vocab.unk_index) for token in sm]
        X = [self.vocab.sos_index] + content + [self.vocab.eos_index]
        padding = [self.vocab.pad_index] * (self.seq_len - len(X))
        X.extend(padding)
        return torch.tensor(X)
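
# Usage sketch (illustrative, not from the original file): one way to feed
# Seq2seqDataset into a DataLoader. It assumes a vocabulary object exposing
# `stoi`, `unk_index`, `sos_index`, `eos_index` and `pad_index`; the
# `WordVocab` import, the file paths and the column name are placeholders
# to be adapted to your project.
if __name__ == '__main__':
    from build_vocab import WordVocab  # assumed vocab helper

    smiles = pd.read_csv('data/smiles.csv')['smiles'].values  # placeholder corpus
    vocab = WordVocab.load_vocab('data/vocab.pkl')            # placeholder vocab file
    dataset = Seq2seqDataset(smiles, vocab, seq_len=MAX_LEN)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    batch = next(iter(loader))
    print(batch.shape)  # torch.Size([64, 220]): padded token-id sequences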
