# Attentionless VOcoder Streaming
## Usage

```python
from huggingface_hub import hf_hub_download
import soundfile
import torch
from transformers import Wav2Vec2PreTrainedModel, PretrainedConfig
from torch import nn
import torch.nn.functional as F
class Voc(Wav2Vec2PreTrainedModel):
    '''When switching to a different batch size, call Voc()._flush().
    '''

    def __init__(self, config=PretrainedConfig()):
        super().__init__(config=config)
        self.encoder_transformer = VocTransformer()
        self.decoder_transformer = VocTransformer()
        self.encoder = SEANetEncoder()
        self.decoder = SEANetDecoder()
        self.sample_rate = 24000
        self.quantizer = SplitResidualVectorQuantizer()
        self.downsample = BufferConv1d(512, 512, kernel_size=4, stride=2, groups=1, bias=False)
        upsample_channel_wise_bug = True
        self.upsample = BufferConvTranspose1d(512, 512, kernel_size=4,
                                              groups=512 if upsample_channel_wise_bug else 1,
                                              stride=2, bias=False)
        self.frame_rate = 12.5
        self.encode_buffer = None

    def _flush(self):
        '''The streaming buffers still hold tensors of the previous batch size;
        call Voc()._flush() to clear them before encoding a different batch size.
        '''
        self.encode_buffer = None  # holds unused samples (incomplete windows of len < 1920); 1920 samples produce 1 token
        self.downsample.previous = None
        self.upsample.partial = None
        for arch in [self.encoder, self.decoder]:
            for _m in arch.model:
                if type(_m) is SEANetResnetBlock:
                    for _b in _m.block:
                        if type(_b) is BufferConv1d:
                            _b.previous = None
                if type(_m) is BufferConv1d:
                    _m.previous = None
                if type(_m) is BufferConvTranspose1d:
                    _m.partial = None

    @torch.no_grad()
    def encode(self, x):
        '''24 kHz audio to codes
        x : [bs, 1, samples at 24 kHz]
        c : [bs, n_q, time] - 1920 audio samples produce 1 time frame (of n_q codebooks)
        '''
        if self.encode_buffer is not None:
            x = torch.cat([self.encode_buffer, x], 2)
        _bs, _1, _len = x.shape
        num_frames = int(_len / 1920)
        leftover = x[:, :, num_frames * 1920:]
        if leftover.shape[2] > 0:
            self.encode_buffer = leftover
        else:
            self.encode_buffer = None
        torch.cuda.empty_cache()
        if num_frames > 0:
            c = []
            for n in range(num_frames):
                e = self.encoder(x[:, :, n * 1920:(n + 1) * 1920])
                e = self.encoder_transformer(e)
                e = self.downsample(e)
                _c = self.quantizer.encode(e)
                c.append(_c)
            c = torch.cat(c, 2)
        else:
            # num_frames = 0 early exit: x held fewer than 1920 samples, so they were buffered and no token can be emitted yet
            c = torch.empty(_bs, len(self.quantizer._acoustic_books) + 1, 0, device=x.device)
        return c

    @torch.no_grad()
    def decode(self, c):
        '''codes to 24 kHz audio
        c: [bs, n_q, n_tokens]
        x: [bs, 1, n_tokens * 1920]
        '''
        _hidden = []
        for i in range(c.shape[2]):
            x = self.quantizer.decode(c[:, :, i:i+1])
            x = self.upsample(x)
            x = self.decoder_transformer(x)
            x = self.decoder(x)
            _hidden.append(x)
        return torch.cat(_hidden, 2)  # [bs, 1, n_tokens * 1920] at 24 kHz

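# Frame arithmetic behind encode()/decode(): sample_rate / frame_rate = 24000 / 12.5 = 1920,
# i.e. one token per 1920 audio samples, which is why encode() stores any tail shorter than
# 1920 samples in encode_buffer and only emits it together with the next chunk.
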
class SEANetResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        kernel_sizes=[3, 1],
    ):
        super().__init__()
        block = []
        for i, kernel_size in enumerate(kernel_sizes):
            block += [
                nn.ELU(),
                BufferConv1d(
                    dim if i == 0 else dim // 2,
                    dim // 2 if i == 0 else dim,
                    kernel_size=kernel_size,
                    bias=True,
                ),
            ]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)

class SEANetEncoder(nn.Module):
    def __init__(
        self,
        channels=1,  # DOES NOT SUPPORT STEREO
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3,
    ):
        super().__init__()
        self.ratios = list(reversed(ratios))
        del ratios
        mult = 1
        model = [
            BufferConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                bias=True
            )
        ]
        for i, ratio in enumerate(self.ratios):
            model += [SEANetResnetBlock(mult * n_filters),
                      nn.ELU(),
                      BufferConv1d(mult * n_filters,
                                   mult * n_filters * 2,
                                   kernel_size=ratio * 2,
                                   stride=ratio,
                                   bias=True)]
            mult *= 2
        model += [nn.ELU(),
                  BufferConv1d(mult * n_filters,
                               dimension,
                               last_kernel_size,
                               bias=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

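# Total encoder stride: 4 * 5 * 6 * 8 = 960, so each 1920-sample window corresponds to
# 2 latent steps of width 512, which Voc.downsample (kernel 4, stride 2) reduces to the
# single frame per token that the quantizer consumes.
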
class SEANetDecoder(nn.Module):
    def __init__(
        self,
        channels=1,
        dimension=512,
        n_filters=64,
        ratios=[8, 6, 5, 4],
        kernel_size=7,
        last_kernel_size=3):
        super().__init__()
        mult = int(2 ** len(ratios))
        model = [BufferConv1d(dimension,
                              mult * n_filters,
                              kernel_size,
                              bias=True)]
        # UP
        for i, ratio in enumerate(ratios):
            model += [nn.ELU(),
                      BufferConvTranspose1d(mult * n_filters,
                                            mult * n_filters // 2,
                                            kernel_size=ratio * 2,
                                            stride=ratio,
                                            bias=True),
                      SEANetResnetBlock(mult * n_filters // 2)]
            mult //= 2
        # LAST
        model += [
            nn.ELU(),
            BufferConv1d(
                n_filters,
                channels,
                last_kernel_size,
                bias=True
            ),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class BufferConv1d(nn.Conv1d):
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.previous = None

    def forward(self, x):
        k = self.kernel_size[0]
        if self.previous is not None:
            x = torch.cat([self.previous, x], 2)
        else:  # self.previous is None => left-pad the very first chunk
            if k == 3:
                p = (2, 0)
                x = F.pad(x, p, mode='replicate')  # skip connections in SEANetResnetBlock
            elif k == 4:  # Voc.downsample; replicate padding avoids a start-up pulse (cf. the ConvTranspose upsample, the first conv met by decode())
                p = (3, 0)
                x = F.pad(x, p, mode='replicate')
            elif k == 7:
                p = (6, 0)
                x = F.pad(x, p, mode='replicate')
            elif k == 16:
                p = (2, 0)
                x = F.pad(x, p, mode='replicate')  # constant padding also works here without a pulse
        num_frames = int((x.shape[2] - self.kernel_size[0]) / self.stride[0]) + 1  # +1: the kernel starts at the left of x and makes (I-k)/s jumps
        offset = num_frames * self.stride[0]
        self.previous = x[..., offset:]
        return super().forward(x)

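# Streaming detail: everything past the last full kernel window (offset = num_frames * stride)
# is kept in self.previous and prepended to the next chunk, so chunk-by-chunk calls see the
# same left context as a single pass over the whole signal; only the very first chunk is padded.
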
class BufferConvTranspose1d(nn.ConvTranspose1d):
    # The last (kernel_size - stride) output steps still need contributions from future
    # inputs (https://distill.pub/2016/deconv-checkerboard/), so they are held in
    # self.partial and completed on the next call.
    def __init__(self,
                 *args,
                 **kwargs):
        super().__init__(*args,
                         **kwargs)
        self.partial = None

    def forward(self, x):
        out = super().forward(x)
        OT = out.shape[2]
        invalid_steps = self.kernel_size[0] - self.stride[0]
        if self.partial is not None:
            PT = self.partial.shape[-1]
            if self.bias is not None:
                out[..., :PT] += self.partial - self.bias[:, None]  # self.partial already includes the bias from the previous call
            else:
                out[..., :PT] += self.partial  # for the Voc.upsample ConvTranspose (no bias)
        self.partial = out[..., OT - invalid_steps:]
        out = out[..., :OT - invalid_steps]
        return out

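# For Voc.upsample (kernel_size=4, stride=2) this overlap-add works out as: 1 input step
# yields 4 raw outputs, the last 4 - 2 = 2 are withheld in self.partial, and the next call
# adds them onto its own first 2 outputs, so each step contributes exactly stride samples.
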
class CodeBook(nn.Module):
    def __init__(self, dim, codebook_size):
        super().__init__()
        self.register_buffer('_e', torch.zeros(codebook_size, dim))

    def encode(self, x):
        dist = torch.cdist(
            x.transpose(1, 2),   # [bs, time, 256]
            self._e[None, :, :]  # [1, 2048, 256]
        )
        codes = dist.argmin(2)
        return codes

    def decode(self, codes):
        quantized = F.embedding(codes, self._e)
        return quantized.transpose(1, 2)  # [bs, 256, time]

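# encode() snaps every time step to the index of its nearest codebook row (Euclidean
# distance via cdist, then argmin); decode() is the inverse lookup (F.embedding), so
# decode(encode(x)) replaces x by its closest codebook vectors.
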
class SplitResidualVectorQuantizer(nn.Module):
    def __init__(self,
                 n_q=None):
        super().__init__()
        self.in_proj_s = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.in_proj_a = torch.nn.Conv1d(512, 256, 1, bias=False)
        self.out_proj_s = torch.nn.Conv1d(256, 512, 1, bias=False)  # reused for all _acoustic_books
        self.out_proj_a = torch.nn.Conv1d(256, 512, 1, bias=False)
        self.layers = nn.ModuleList([CodeBook(dim=256, codebook_size=2048) for _ in range(18)])
        # self._acoustic_books = range(1, 16)  # official Mimi
        # CODEBOOKS
        # Here we reuse RVQ codebooks for higher fidelity!
        # Book 0 is excluded as it has a different projection (in_proj_s)
        self._acoustic_books = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 17, 17, 17, 17]

    def encode(self, x):
        indices = self.layers[0].encode(self.in_proj_s(x))  # integer codes of the semantic book
        all_indices = [indices[:, None, :], ]
        x = self.in_proj_a(x)
        for _cb in self._acoustic_books:
            indices = self.layers[_cb].encode(x)
            x = x - self.layers[_cb].decode(indices)
            all_indices.append(indices[:, None, :])
        codes = torch.cat(all_indices, 1)
        return codes

    def decode(self, codes):
        _s = self.layers[0].decode(codes[:, 0, :])
        _a = torch.zeros([1, 1], device=codes.device)
        for i, _cb in enumerate(self._acoustic_books):
            _a = _a + self.layers[_cb].decode(codes[:, i + 1, :])
        return self.out_proj_s(_s) + self.out_proj_a(_a)  # [bs, 512, time]

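# Residual quantization: book 0 codes the "semantic" projection, then every acoustic book
# codes what the previous books left unexplained (x minus the decoded vector is passed on),
# and decode() sums the per-book embeddings back up. Listing book 17 five times in
# _acoustic_books just runs extra residual rounds with that same codebook; this is the
# reuse-for-fidelity tweak noted above, versus the official Mimi list range(1, 16).
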
class VocAttention(nn.Module):
    def __init__(self,
                 embed_dim):
        super().__init__()
        self.fused_proj = nn.Parameter(torch.zeros(embed_dim, embed_dim))

    def forward(self, x):
        '''streaming bypass of the attention used in training'''
        if x.shape[1] > 1:
            x = x.mean(1, keepdim=True)
        x = torch.matmul(x, self.fused_proj)
        return x  # [bs, 1, d]; broadcasts back to x.shape[1] (e.g. 2) in the residual add

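# "Attentionless" in practice: instead of self-attention, the short per-frame sequence is
# mean-pooled over time and sent through a single fused projection, so the transformer keeps
# no state between frames; only the Buffer* convolutions above carry streaming state.
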
class VocTransformerLayer(nn.Module):
    def __init__(self, d_model=512, dim_feedforward=2048):
        super().__init__()
        self.self_attn = VocAttention(embed_dim=d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)

    def forward(self, x):
        x = x + self.self_attn(self.norm1(x))
        return x + self.linear2(F.gelu(self.linear1(self.norm2(x))))

class VocTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(VocTransformerLayer() for _ in range(8))

    def forward(self, x):
        x = x.transpose(1, 2)  # [bs, 512, time] -> [bs, time, 512]
        for la in self.layers:
            x = la(x)
        return x.transpose(1, 2)

device = 'cpu'  # or 'cuda:0'
model = Voc.from_pretrained('ivao0/voc').to(device)
x, _ = soundfile.read(hf_hub_download(repo_id='ivao0/voc', filename='true.wav'))  # 24 kHz
x = torch.from_numpy(x[None, None, :]).to(dtype=torch.float, device=device)
codes = model.encode(x)  # [bs, len(_acoustic_books) + 1, T]
y = model.decode(codes)  # audio signal at 24 kHz
soundfile.write('reconstruct.wav', y[0, 0, :].cpu().numpy(), 24000)
model._flush()  # clear streaming buffers before calling encode()/decode() with a different batch size
```
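
Because the `Buffer*` layers keep their leftover state between calls and `encode()` buffers incomplete 1920-sample windows, the waveform can also be fed in chunks. Below is a minimal streaming sketch; the chunk size of 4000 samples is an arbitrary choice for illustration, and everything else reuses the objects defined above.

```python
model._flush()                      # start the streaming pass from clean buffers
chunk = 4000                        # arbitrary chunk size, deliberately not a multiple of 1920
audio = []
for i in range(0, x.shape[2], chunk):
    c = model.encode(x[:, :, i:i + chunk])
    if c.shape[2] > 0:              # a chunk may be too short to emit a token yet
        audio.append(model.decode(c))
y_stream = torch.cat(audio, 2)      # should match y from the full-pass example above
soundfile.write('reconstruct_stream.wav', y_stream[0, 0, :].cpu().numpy(), 24000)
```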
Base model: kyutai/mimi