import torch
from torch import nn

from data import Tokenizer


class ResidualBlock(nn.Module):
    """Two 1-D convolutions with batch norm, PReLU and dropout, plus a skip connection."""

    def __init__(self, num_channels, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(num_channels)
        self.conv2 = nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_channels)
        self.prelu = nn.PReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch_size, num_channels, seq_len); the shape is preserved throughout.
        residual = x
        x = self.prelu(self.bn1(self.conv1(x)))
        x = self.dropout(x)
        x = self.bn2(self.conv2(x))
        x = self.prelu(x)
        x = self.dropout(x)
        # NOTE: the residual is added after the second activation and dropout;
        # some residual variants apply the activation after the sum instead.
        x = x + residual
        return x
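
# Illustrative shape check (a sketch added for exposition, not part of the
# original module): a ResidualBlock maps (batch, channels, length) to the
# same shape, which is what allows stacking the blocks in nn.Sequential below.
def _residual_block_shape_demo():
    block = ResidualBlock(num_channels=8, dropout=0.1).eval()  # eval: no dropout, deterministic
    x = torch.randn(2, 8, 16)
    assert block(x).shape == x.shape  # (2, 8, 16) in and out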
class Seq2SeqCNN(nn.Module):
    """Convolutional sequence-to-sequence model: an embedding, an input
    convolution, a stack of residual blocks, and a linear head that predicts
    `many_to_one` target tokens per source position."""

    def __init__(self, config):
        super().__init__()
        dict_size_src = config['dict_size_src']
        dict_size_trg = config['dict_size_trg']
        embedding_dim = config['embedding_dim']
        num_channels = config['num_channels']
        num_residual_blocks = config['num_residual_blocks']
        dropout = config['dropout']
        many_to_one = config['many_to_one']
        self.config = config
        self.embedding = nn.Embedding(dict_size_src, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm1d(num_channels)
        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(num_channels, dropout) for _ in range(num_residual_blocks))
        )
        self.fc = nn.Linear(num_channels, dict_size_trg * many_to_one)
        self.dropout = nn.Dropout(dropout)
        self.dict_size_trg = dict_size_trg

    def forward(self, src):
        # src: (batch_size, seq_len)
        batch_size = src.size(0)
        # Conv1d expects channels first, so move the embedding dim before seq_len.
        embedded = self.embedding(src).permute(0, 2, 1)  # (batch, emb_dim, seq_len)
        conv_out = self.conv(embedded)  # (batch, num_channels, seq_len)
        conv_out = self.dropout(torch.relu(self.bn(conv_out)))
        res_out = self.residual_blocks(conv_out)
        # Global skip connection around the whole residual stack.
        res_out = res_out + conv_out
        # Back to (batch, seq_len, num_channels) for the linear head.
        out = self.fc(self.dropout(res_out.permute(0, 2, 1)))
        # One distribution over the target vocabulary for each of the
        # `many_to_one` target slots at every source position.
        out = out.view(batch_size, -1, self.config['many_to_one'], self.dict_size_trg)
        return out
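
# End-to-end shape walkthrough (illustrative sketch; the config values below
# are made up, only the key names come from Seq2SeqCNN.__init__ above).
def _seq2seq_cnn_shape_demo():
    config = {
        'dict_size_src': 100, 'dict_size_trg': 120, 'embedding_dim': 32,
        'num_channels': 64, 'num_residual_blocks': 2, 'dropout': 0.1,
        'many_to_one': 2,
    }
    model = Seq2SeqCNN(config).eval()
    src = torch.randint(0, config['dict_size_src'], (4, 10))  # (batch, seq_len)
    out = model(src)
    # One vocabulary distribution per (source position, target slot):
    assert out.shape == (4, 10, 2, 120)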
def init_model(path, device="cpu"):
    """Rebuild a Seq2SeqCNN from a checkpoint that stores both the config and the weights."""
    d = torch.load(path, map_location=device)
    model = Seq2SeqCNN(d['config']).to(device)
    model.load_state_dict(d['state_dict'])
    return model


def _predict(model, src, device):
    model.eval()
    src = src.to(device)
    with torch.no_grad():
        output = model(src)
    # Greedy decoding: take the argmax over the target vocabulary.
    # (An alternative would be to sample from softmax(output) with torch.multinomial.)
    _, pred = torch.max(output, dim=-1)
    return pred


def predict(model, tokenizer: Tokenizer, text: str, device):
    print('text:', text)
    if not text:
        return ''
    text_encoded = tokenizer.encode_src(text)
    batch = text_encoded.unsqueeze(0)  # add a batch dimension
    prd = _predict(model, batch, device)[0]
    # Drop predictions made at padded source positions.
    prd = prd[batch[0] != tokenizer.src_pad_idx, :]
    predicted_text = ''.join(tokenizer.decode_trg(prd))
    print('predicted_text:', repr(predicted_text))
    return predicted_text  # optionally strip zero-width non-joiners: .replace('\u200c', '')
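
# Usage sketch (illustrative; the checkpoint path is hypothetical and the
# Tokenizer constructor depends on data.py, so this is left as a comment):
#
#     model = init_model("checkpoint.pt", device="cpu")
#     tokenizer = Tokenizer(...)  # construction depends on data.py
#     print(predict(model, tokenizer, "some input text", "cpu"))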