import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random

from utilities.constants import *
from utilities.device import get_device

from .positional_encoding import PositionalEncoding
from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR


class MusicTransformer(nn.Module):
    """
    Music Transformer: a decoder-only transformer for symbolic music generation.
    It is built from nn.Transformer with zero true decoder layers and a DummyDecoder
    that passes the encoder output straight through, so only the masked encoder
    stack is applied. With rpr=True, the encoder uses relative positional
    representation (RPR) attention layers instead of the stock ones.
    """

    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        super(MusicTransformer, self).__init__()

        self.dummy = DummyDecoder()

        self.nlayers = n_layers
        self.nhead = num_heads
        self.d_model = d_model
        self.d_ff = dim_feedforward
        self.dropout = dropout
        self.max_seq = max_sequence
        self.rpr = rpr

        # Input embedding and sinusoidal positional encoding
        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)
        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        # Base transformer: a decoder-only model is built from masked encoder layers;
        # the DummyDecoder simply returns the encoder output unchanged.
        if(not self.rpr):
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy
            )
        # RPR Transformer: swap in a custom encoder with relative positional attention
        else:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
            )

        # Final projection back to the vocabulary; softmax is only applied at generation time
        self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=True):
        """
        Takes a batch of token-id sequences of shape (batch, seq_len) and returns
        next-token logits of shape (batch, seq_len, VOCAB_SIZE). By default a square
        subsequent (causal) mask is applied so each position only attends to earlier
        positions.
        """
        if(mask is True):
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(x.device)
        else:
            mask = None

        x = self.embedding(x)

        # nn.Transformer expects (seq_len, batch, d_model)
        x = x.permute(1, 0, 2)
        x = self.positional_encoding(x)

        # There are no true decoder layers, so tgt is unused; the DummyDecoder
        # returns the encoder output (memory) as-is.
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)

        # Back to (batch, seq_len, d_model)
        x_out = x_out.permute(1, 0, 2)

        y = self.Wout(x_out)
        return y
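
    # Note on masking: nn.Transformer.generate_square_subsequent_mask(n) returns an
    # (n, n) float mask that is 0.0 on and below the diagonal and -inf above it, so
    # attention at position i is restricted to positions <= i. This is the causal
    # constraint that autoregressive generation below relies on.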

    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
        """
        Autoregressively generates a sequence from a primer of shape (batch, primer_len).
        With beam > 0, a top-k beam step is taken with probability beam_chance;
        otherwise the next token is sampled from the predicted distribution.
        """
        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        batch_size = primer.shape[0]
        gen_seq = torch.full((batch_size, target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

        num_primer = primer.shape[1]
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())

        cur_i = num_primer
        while(cur_i < target_seq_length):
            # Next-token probabilities, restricted to indices below TOKEN_END
            # (i.e., excluding the end and pad tokens)
            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i-1, :]

            if(beam == 0):
                beam_ran = 2.0
            else:
                beam_ran = random.uniform(0, 1)

            if(beam_ran <= beam_chance):
                # Beam step: take the top `beam` (sequence, token) pairs across the whole batch.
                # The probabilities were sliced to :TOKEN_END, so the flattened index must be
                # decoded with that width rather than VOCAB_SIZE.
                num_classes = token_probs.shape[-1]
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                beam_rows = top_i // num_classes
                beam_cols = top_i % num_classes

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                # Sample the next token for each sequence in the batch
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                gen_seq[:, cur_i] = next_token

                # Stop early if the end-of-sequence token is produced (assumes a single sequence)
                if(next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if(cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]


# Dummy decoder for nn.Transformer: simply returns the encoder output (memory),
# so the overall model behaves as a decoder-only (masked encoder) transformer.
class DummyDecoder(nn.Module):
    def __init__(self):
        super(DummyDecoder, self).__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, **kwargs):
        return memory
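

# The block below is an illustrative usage sketch, not part of the original module.
# It assumes the surrounding repo layout (utilities.constants providing VOCAB_SIZE,
# TOKEN_PAD, TOKEN_END, and TORCH_LABEL_TYPE, plus utilities.device.get_device) and
# uses small, arbitrary hyperparameters so it runs quickly. Because of the relative
# imports above, run it as a module from the repository root (python -m ...) rather
# than as a standalone script.
if __name__ == "__main__":
    model = MusicTransformer(n_layers=2, num_heads=4, d_model=128, dim_feedforward=256,
                             max_sequence=256, rpr=False).to(get_device())

    # Training-style forward pass: (batch, seq) token ids -> (batch, seq, VOCAB_SIZE) logits
    tokens = torch.randint(0, VOCAB_SIZE, (2, 32), device=get_device())
    logits = model(tokens)
    print("logits shape:", logits.shape)

    # Generation from a short random primer (pure sampling, no beam)
    model.eval()
    with torch.no_grad():
        primer = torch.randint(0, TOKEN_END, (1, 16), device=get_device())
        generated = model.generate(primer=primer, target_seq_length=64, beam=0)
    print("generated shape:", generated.shape)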