# PicoAudio2 / models / content_encoder / caption_encoder.py
# rookie9's picture
# Upload 77 files
# f582ec6 verified
from typing import Any
import torch
import torch.nn as nn
import random
from utils.audiotime_event_merge import replace_event_synonyms
def decode_data(
    line_onset_str,
    latent_length,
    *,
    sample_rate=24000,
    downsample_ratio=480,
    max_events=4,
):
    """
    Build a binary event-activity matrix from a formatted onset string.

    Args:
        line_onset_str (str): String containing event names and onset intervals,
            formatted like "event1__start1-end1_start2-end2--event2__start1-end1".
            Events are separated by "--", an event name and its intervals by
            "__", intervals by "_", and start/end times (in seconds) by "-".
        latent_length (int): Number of columns (latent frames) of the output matrix.
        sample_rate (int): Audio sample rate. Defaults to 24000.
        downsample_ratio (int): Number of audio samples per latent frame
            (the ratio used to align with the VAE). Defaults to 480.
        max_events (int): Number of rows of the output matrix, i.e. the maximum
            number of events that can be written. Defaults to 4.

    Returns:
        line_onset_index (torch.Tensor): Matrix of shape [max_events, latent_length].
            Row i is 1 on the frames covered by the intervals of the i-th event
            and 0 elsewhere.
        line_event (list): List of event names, in the order they appear in
            the onset string.

    Raises:
        ValueError: If an event chunk does not split into exactly one name and
            one interval list, if an interval does not split into exactly one
            start and one end, or if a time is not parseable as a float.
        IndexError: If an interval has to be written for an event whose index
            is >= max_events.

    Notes:
        - Each time (in seconds) is converted to a frame index via
          int(time * sample_rate / downsample_ratio).
        - An interval whose end index is greater than latent_length - 1 is not
          written, and the remaining intervals of that event are skipped too.
          This also drops an interval ending exactly at latent_length.
    """
    line_onset_index = torch.zeros((max_events, latent_length))
    line_event = []
    for event_idx, event_onset in enumerate(line_onset_str.split("--")):
        event, instance = event_onset.split("__")
        line_event.append(event)
        for start_end in instance.split("_"):
            start, end = start_end.split("-")
            # Same arithmetic order as the original (multiply, then divide) so
            # the resulting frame indices are bit-for-bit unchanged.
            start = int(float(start) * sample_rate / downsample_ratio)
            end = int(float(end) * sample_rate / downsample_ratio)
            # Stop processing this event at the first interval that runs past
            # the last valid frame index.
            if end > (latent_length - 1):
                break
            line_onset_index[event_idx, start:end] = 1
    return line_onset_index, line_event
class ContentEncoder(nn.Module):
"""
ContentEncoder encodes TCC and TDC information.
"""
def __init__(
self,
text_encoder: nn.Module= None,
):
super().__init__()
self.text_encoder = text_encoder
self.pool = nn.AdaptiveAvgPool1d(1)
def encode_content(
self, batch_content: list[Any], device: str | torch.device
):
batch_output = []
batch_mask = []
batch_onset = []
length_list = []
print(batch_content)
for content in batch_content:
caption = content["caption"]
onset = content["onset"]
length = int(float(content["length"]) *24000/480)
# Replacement for AudioTime
print(onset)
replace_label = content.get("replace_label", "False")
if replace_label == "True":
caption, onset = replace_event_synonyms(caption, onset)
# Handle random onset case for read data without timestamp
if content["onset"] == "random":
length_list.append(length)
"""
fixed embedding. Actually it's a sick sentence, a error during training, kept to match the checkpoint.
You can change it to sentence that difference to captions in datasets.
The use of fixed text to obtain encoding is for numerical stability.
We attempted to use learnable unified encoding during training, but the results were not satisfactory.
"""
event = "There is no event here"
event_embed = self.text_encoder([event.replace("_", " ")])["output"]
event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
event_embed = event_embed.flatten().unsqueeze(0)
new_onset = event_embed.repeat(length, 1).T
else:
onset_matrix, events = decode_data(onset, length)
length_list.append(length)
new_onset = torch.zeros((1024, length), device=device) # 1024 for T5
# TDC
for (idx, event) in enumerate(events):
with torch.no_grad():
event_embed = self.text_encoder([event.replace("_", " ")])["output"]
event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
event_embed = event_embed.flatten().unsqueeze(0)
mask = (onset_matrix[idx, :] == 0)
cols = mask.nonzero(as_tuple=True)[0]
new_onset[:, cols] += event_embed.T.float()
# TCC
output_dict = self.text_encoder([caption])
batch_output.append(output_dict["output"][0])
batch_mask.append(output_dict["mask"][0])
batch_onset.append(new_onset)
# Pad all sequences in the batch to the same length for batching
batch_output = nn.utils.rnn.pad_sequence(
batch_output, batch_first=True, padding_value=0
)
batch_mask = nn.utils.rnn.pad_sequence(
batch_mask, batch_first=True, padding_value=False
)
batch_onset = nn.utils.rnn.pad_sequence(
batch_onset, batch_first=True, padding_value=0
)
return batch_output, batch_mask, batch_onset, length_list