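"""Utilities for loading multimodal media (images, videos, speech, and sound).

`extract_media` walks chat messages, loads any attached media objects into
model-ready arrays/features, and replaces them with the corresponding media
tokens in the message text.
"""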
import glob
import io
import os
import random
import tarfile
import tempfile
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import decord
import kaldiio
import librosa
import numpy as np
import PIL
import PIL.Image
import requests
import soundfile as sf
import torch
import whisper
from decord import AudioReader, cpu
from transformers import PretrainedConfig

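# Placeholder tokens that extract_media() splices into message text wherever a
# media object appears.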
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
    "speech": "<speech>",
    "sound": "<sound>",
}


class Media:
    """Base class for media objects."""


class File(Media):
    """File-based media object."""

    def __init__(self, path: str) -> None:
        self.path = path


class Image(File):
    """Image media object."""


class Video(File):
    """Video media object."""


class Speech(File):
    """Speech audio media object."""

    def __init__(self, path, extension: Optional[str] = None) -> None:
        self.path = path
        self.extension = extension


class Sound(File):
    """Sound/music audio media object."""

    def __init__(self, path, extension: Optional[str] = None) -> None:
        self.path = path
        self.extension = extension


def make_list(obj: Any) -> List:
    """Wrap obj in a list if it is not already a list."""
    return obj if isinstance(obj, list) else [obj]


def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    """Extract a PIL image from an Image object, or return a PIL image as-is."""
    if isinstance(image, Image):
        if image.path.startswith("http://") or image.path.startswith("https://"):
            image = PIL.Image.open(requests.get(image.path, stream=True).raw)
        else:
            image = PIL.Image.open(image.path)
    return image


def _load_video_bytesio(
    video_bytesio: BytesIO, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
) -> Tuple[List[PIL.Image.Image], Any, Dict[str, Any]]:
    """Load a video from a BytesIO object by writing it to a temporary file."""
    with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
        temp_video.write(video_bytesio.read())
        temp_video.flush()  # make sure all bytes are on disk before the decoder opens the file
        return _load_video(temp_video.name, num_frames=num_frames, load_aud=load_aud, config=config)


def get_overlap(inp1, inp2):
    """Calculate the overlapping time window between a video clip and an audio segment.

    Args:
        inp1 (list): [start_sec, end_sec]
        inp2 (list): [start_sec, end_sec]

    Returns:
        tuple or None: (overlap_start, overlap_end) if an overlap exists, else None.
    """
    overlap_start = max(inp1[0], inp2[0])
    overlap_end = min(inp1[1], inp2[1])

    if overlap_start < overlap_end:
        return (overlap_start, overlap_end)
    return None


def _load_video(
    video_path: str, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
) -> Tuple[List[PIL.Image.Image], Any, Dict[str, Any]]:
    # A directory is treated as a folder of pre-extracted frames.
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        frames = [PIL.Image.open(frame_paths[index]) for index in indices]
        # Frame folders carry no audio track; return an empty info dict so callers
        # can unpack the same three values as for regular videos.
        return frames, None, {}

    vidcap = cv2.VideoCapture(video_path)

    # Optionally load the audio track so it can be interleaved with the frames.
    audio_info = None
    if load_aud:
        try:
            aud_feature, audio_info = _load_speech(video_path, config)
        except Exception as e:
            print(f"Failed to load audio from video '{video_path}': {e}")
            aud_feature = None
    else:
        aud_feature = None

    # Walk backwards from the reported frame count to the last frame that can
    # actually be grabbed; container metadata often overstates the count.
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        # The while-else only fires when the loop exhausts without a break.
        raise ValueError(f"Video '{video_path}' has no frames.")

    # Uniformly sample num_frames frame indices across the video.
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)

    fps = vidcap.get(cv2.CAP_PROP_FPS)
    video_duration = frame_count / fps

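    # When audio is interleaved with video, the timeline is cut into fixed-length
    # segments; each segment records which sampled frames fall inside it and the
    # [start, end) range of audio STFT frames that overlap it, so the model can
    # alternate visual and audio features in temporal order.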
    if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
        segment_duration = config.interleaved_video_segment_duration
        if segment_duration == -1:
            raise ValueError("config.interleaved_video_segment_duration is not set")

        segment_vis_indices_list = []
        segment_aud_indices_list = []
        segment_counts = np.ceil(video_duration / segment_duration).astype(int)

        if isinstance(aud_feature, dict):
            aud_feas = aud_feature["input_features"]
        else:
            aud_feas = aud_feature
        audio_start_sec = audio_info["audio_start_sec"]
        audio_end_sec = audio_info["audio_end_sample_sec"]

        # Number of STFT frames produced per second of audio.
        stft_frames_per_second = config.audio_sampling_rate // config.audio_hop_length

        _idx = 0
        aud_sample_start_idx = 0
        for i in range(segment_counts):
            end_frame = min((i + 1) * segment_duration * fps, frame_count)

            # Sampled frame indices that fall inside this segment.
            _indices = []
            while _idx < len(indices) and indices[_idx] < end_frame:
                _indices.append(indices[_idx])
                _idx += 1
            segment_vis_indices_list.append(_indices)

            clip_start_sec = i * segment_duration
            clip_end_sec = min(clip_start_sec + segment_duration, video_duration)

            # Map the segment's time window onto STFT frame indices of the loaded audio.
            overlap = get_overlap([clip_start_sec, clip_end_sec], [audio_start_sec, audio_end_sec])
            if overlap is not None:
                aud_sample_end_idx = round((overlap[1] - audio_start_sec) * stft_frames_per_second)
                segment_aud_indices_list.append([aud_sample_start_idx, aud_sample_end_idx])
                aud_sample_start_idx = aud_sample_end_idx
            else:
                segment_aud_indices_list.append([])

    # Decode each sampled frame once; duplicated indices reuse the same frame.
    frames = {}
    frame_times = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = PIL.Image.fromarray(frame)
        frame_times[index] = index / fps

    output_frames = [frames[index] for index in indices if index in frames]
    output_frame_times = [frame_times[index] for index in indices if index in frame_times]

    video_info = {}
    if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
        # Remap raw frame indices to positions in output_frames, dropping any
        # frames that failed to decode.
        new_segment_vis_indices_list = []
        processed_frame_index = 0
        for i, segment_indices in enumerate(segment_vis_indices_list):
            new_segment_vis_indices_list.append([])
            for index in segment_indices:
                if index in frames:
                    new_segment_vis_indices_list[-1].append(processed_frame_index)
                    processed_frame_index += 1
        segment_vis_indices_list = new_segment_vis_indices_list

        video_info["segment_vis_indices_list"] = segment_vis_indices_list
        video_info["segment_aud_indices_list"] = segment_aud_indices_list
    video_info["expected_frame_count"] = len(indices)
    video_info["video_path"] = video_path
    if audio_info is not None:
        audio_info["video_path"] = video_path
    video_info["has_audio"] = aud_feature is not None
    video_info["video_duration"] = video_duration
    video_info["audio_info"] = audio_info
    video_info["video_frame_times"] = output_frame_times

    return output_frames, aud_feature, video_info


def _extract_video(video: Video, config: PretrainedConfig):
    """Load frames (and optionally audio) for a Video whose path may be a file path or BytesIO."""
    num_frames = config.num_video_frames
    aud_fea = None

    if getattr(config, "fps", 0) != 0:
        print("Extracting frames from video at a specified FPS is not supported yet. Ignored.")

    if isinstance(video.path, BytesIO):
        frames, aud_fea, video_info = _load_video_bytesio(
            video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
        )
    else:
        frames, aud_fea, video_info = _load_video(
            video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
        )

    if config.load_audio_in_video:
        return frames, aud_fea, video_info
    return frames, video_info


def soundFile_read_audio(audio_file, offset=None, duration=None, dtype="float32"):
    """Read audio samples (optionally a sub-window) with soundfile."""
    if dtype not in ["int32", "float32"]:
        print("audio dtype must be int32 or float32. Defaulting to float32.")
        dtype = "float32"

    if isinstance(audio_file, bytes):
        audio_file = io.BytesIO(audio_file)
    with sf.SoundFile(audio_file, "r") as f:
        sample_rate = f.samplerate
        if offset is not None and offset > 0:
            f.seek(int(offset * sample_rate))
        if duration is not None and duration > 0:
            samples = f.read(int(duration * sample_rate), dtype=dtype)
        else:
            samples = f.read(dtype=dtype)
    return samples, sample_rate


def load_audio_from_tar(tar_file, audio_file):
    """Load an audio member from a tar archive; returns (samples, sample_rate)."""
    with tarfile.open(tar_file, "r") as tar:
        audio_member = tar.getmember(audio_file)
        audio_file = tar.extractfile(audio_member)
        # librosa.load resamples to its default 22050 Hz unless sr is passed.
        return librosa.load(audio_file)


def _load_audio_file(audio_path: str, config: PretrainedConfig):
    if audio_path is None:
        return None, None

    dirname = os.path.dirname(audio_path)
    filename = os.path.basename(audio_path)

    # A path like "archive.tar/clip.wav" refers to a member inside a tar file.
    if dirname.endswith(".tar"):
        speech, sample_rate = load_audio_from_tar(dirname, filename)
    else:
        sample_rate = config.audio_sampling_rate
        speech = whisper.load_audio(audio_path, sr=sample_rate)

    return speech, sample_rate


def _load_audio(audio: Union[str, dict], config: PretrainedConfig):
    if isinstance(audio, str):
        return _load_audio_file(audio, config)
    elif isinstance(audio, dict):
        audio_sample = audio["sample"]
        if isinstance(audio_sample, (bytes, io.BytesIO)):
            offset = audio.get("offset", None)
            duration = audio.get("duration", None)
            dtype = audio.get("dtype", "float32")
            return soundFile_read_audio(audio_sample, offset=offset, duration=duration, dtype=dtype)
        elif isinstance(audio_sample, np.ndarray):
            return audio_sample, audio.get("sample_rate")
        else:
            raise ValueError(f"Expected the loaded audio to be a processed numpy array or raw bytes. Got {type(audio_sample)}")
    else:
        raise ValueError(f"Expected input to be a path string or dict. Got {type(audio)}")


def _whisper_process(audio, sample_rate, audio_chunk_length, max_chunks_per_file):
    """Split audio into fixed-length chunks and compute a 128-bin log-mel spectrogram per chunk."""
    outputs = []
    num_audio_chunks = 0

    chunk_length = audio_chunk_length * sample_rate
    for i in range(0, len(audio), chunk_length):
        chunk = audio[i : i + chunk_length]
        chunk = whisper.pad_or_trim(chunk)
        if chunk.dtype != np.float32:
            chunk = chunk.astype(np.float32)
        mel = whisper.log_mel_spectrogram(chunk, n_mels=128)
        num_audio_chunks += 1
        outputs.append(mel)
        if num_audio_chunks == max_chunks_per_file:
            break

    frames = torch.stack(outputs, dim=0)
    return frames.numpy().tolist()


def _load_speech(speech, config: PretrainedConfig):
    """Load an audio stream from a path, dict, or BytesIO and pad it for Whisper."""
    if isinstance(speech, str):
        speech_path = speech
    else:
        speech_path = speech.path

    if speech_path is None:
        return None, None

    # audio_chunk_length is either a number of seconds, or "max" / "max_<seconds>",
    # meaning: load all available audio (optionally capped at <seconds>).
    load_max_audio = isinstance(config.audio_chunk_length, str) and "max" in config.audio_chunk_length
    if config.audio_chunk_length and not load_max_audio:
        try:
            config.audio_chunk_length = int(config.audio_chunk_length)
        except Exception as e:
            print(f"Error setting audio_chunk_length: {e}")
            raise e

    if load_max_audio:
        audio_n_samples_limit = None  # recomputed per file inside get_audio
    else:
        audio_n_samples_limit = config.audio_chunk_length * config.audio_sampling_rate

    def load_wav(speech_path):
        speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
        cur_max_length = speech.shape[0]
        ori_audio_duration = cur_max_length / sr
        return speech, ori_audio_duration

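    # get_audio selects which waveform window to keep, in one of three modes:
    #   - random crop (config.random_audio_sample set, fixed chunk length),
    #   - "max" / "max_<seconds>": keep everything, optionally capped,
    #   - default: keep the first audio_n_samples samples.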
    def get_audio(speech, audio_n_samples):
        # decord's AudioReader is (channels, samples); numpy waveforms are (samples,).
        if isinstance(speech, decord.audio_reader.AudioReader):
            ori_n_samples = speech.shape[1]
        else:
            ori_n_samples = speech.shape[0]

        audio_start_sample_id = 0
        audio_end_sample_id = ori_n_samples

        if hasattr(config, "random_audio_sample") and not load_max_audio:
            if ori_n_samples > audio_n_samples:
                audio_start_sample_id = random.randint(0, ori_n_samples - audio_n_samples)
                audio_end_sample_id = audio_start_sample_id + audio_n_samples
        else:
            if load_max_audio:
                if "_" in config.audio_chunk_length:
                    # "max_<seconds>" caps the loaded audio at <seconds>.
                    max_audio_chunk_length = int(config.audio_chunk_length.split("_")[1])
                    max_audio_n_samples = max_audio_chunk_length * config.audio_sampling_rate
                    audio_n_samples = min(ori_n_samples, max_audio_n_samples)
                    audio_end_sample_id = audio_n_samples
                else:
                    audio_n_samples = ori_n_samples
                    audio_end_sample_id = audio_n_samples
            else:
                audio_end_sample_id = min(audio_n_samples, ori_n_samples)

        if isinstance(speech, decord.audio_reader.AudioReader):
            speech = speech[audio_start_sample_id:audio_end_sample_id].asnumpy()[0]
        else:
            speech = speech[audio_start_sample_id:audio_end_sample_id]

        return speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id

    if isinstance(speech_path, dict):
        # Pre-decoded samples, optionally with an offset/duration window.
        if "offset" in speech_path:
            speech, ori_sample_rate = _load_audio(speech_path, config)
        else:
            speech = speech_path["sample"]
            ori_sample_rate = speech_path["sample_rate"]

        speech = librosa.resample(speech, orig_sr=ori_sample_rate, target_sr=config.audio_sampling_rate)

        ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)

    elif isinstance(speech_path, BytesIO):
        if speech.extension == ".wav":
            speech, ori_audio_duration = load_wav(speech_path)
            speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
        else:
            raise ValueError(f"Unsupported audio extension: {speech.extension}")

    elif ".mat" in speech_path or ".ark" in speech_path:
        # Kaldi matrix/archive files.
        rate, speech = kaldiio.load_mat(speech_path)
        speech = librosa.resample(speech, orig_sr=rate, target_sr=config.audio_sampling_rate)
        ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)

    elif ".mp4" in speech_path:
        # Extract the audio track of a video with decord.
        ar = AudioReader(speech_path, ctx=cpu(0), sample_rate=config.audio_sampling_rate, mono=True)
        cur_max_length = ar.shape[1]
        ori_audio_duration = cur_max_length / config.audio_sampling_rate
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(ar, audio_n_samples_limit)

    else:
        assert os.path.exists(speech_path), f"File {speech_path} does not exist"
        speech, ori_audio_duration = load_wav(speech_path)
        speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)

    speech = speech.astype(np.float32)
    # Pad the waveform up to the next multiple of 30 s, Whisper's window size.
    audio_n_samples = int(np.ceil(speech.shape[0] / (config.audio_sampling_rate * 30)) * (config.audio_sampling_rate * 30))
    speech = whisper.pad_or_trim(speech, length=audio_n_samples)

    new_audio_chunk_length = int(audio_n_samples // config.audio_sampling_rate)
    audio_start_sec = audio_start_sample_id / config.audio_sampling_rate
    audio_end_sample_sec = audio_end_sample_id / config.audio_sampling_rate

    audio_info = {}
    audio_info["new_audio_chunk_length"] = new_audio_chunk_length
    audio_info["new_audio_n_samples"] = audio_n_samples
    audio_info["ori_audio_duration"] = ori_audio_duration
    audio_info["audio_start_sec"] = audio_start_sec
    audio_info["audio_end_sample_sec"] = audio_end_sample_sec

    return speech, audio_info


def _extract_speech(speech: Speech, config: PretrainedConfig):
    frames, audio_info = _load_speech(speech, config)
    return frames, audio_info


# Sound (music / ambient audio) shares the speech loading path.
_extract_sound = _extract_speech


def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Extract media objects from chat messages, load them, and splice the
    corresponding media tokens into the message text."""
    media = defaultdict(list)

    if not hasattr(config, "load_audio_in_video"):
        print("Warning: load_audio_in_video not in config; defaulting to False.")
        config.load_audio_in_video = False

    for message in messages:
        text = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                # Strip stray media tokens from raw text; tokens are re-added only
                # where actual media objects appear.
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                text += part
            elif isinstance(part, (Image, PIL.Image.Image)):
                if draft:
                    media["image"].append(part)
                else:
                    media["image"].append(_extract_image(part))
                text += MEDIA_TOKENS["image"]
            elif isinstance(part, Video):
                if draft:
                    media["video"].append(part)
                else:
                    if config.load_audio_in_video:
                        output, aud_fea, video_info = _extract_video(part, config)
                        media["video"].append(output)
                        media["video_info"].append(video_info)
                        if aud_fea is not None:
                            media["sound"].append(aud_fea)
                            media["audio_info"].append(video_info["audio_info"])
                            text += MEDIA_TOKENS["sound"]
                    else:
                        output, video_info = _extract_video(part, config)
                        media["video"].append(output)
                        media["video_info"].append(video_info)
                text += MEDIA_TOKENS["video"]
            elif isinstance(part, Speech):
                if draft:
                    if config.unified_audio_encoder:
                        media["sound"].append(part)
                        text += MEDIA_TOKENS["sound"]
                    else:
                        media["speech"].append(part)
                        text += MEDIA_TOKENS["speech"]
                else:
                    output, audio_info = _extract_speech(part, config)
                    if output is not None:
                        if config.unified_audio_encoder:
                            media["sound"].append(output)
                            text += MEDIA_TOKENS["sound"]
                        else:
                            media["speech"].append(output)
                            text += MEDIA_TOKENS["speech"]
                        media["audio_info"].append(audio_info)
            elif isinstance(part, Sound):
                if draft:
                    media["sound"].append(part)
                    text += MEDIA_TOKENS["sound"]
                else:
                    output, audio_info = _extract_sound(part, config)
                    if output is not None:
                        media["sound"].append(output)
                        media["audio_info"].append(audio_info)
                        text += MEDIA_TOKENS["sound"]
            else:
                print(f"part: {part}")
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = text
    return media
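

# Hypothetical usage sketch (names like `cfg` and "clip.mp4" are illustrative;
# `cfg` must carry the fields referenced above, e.g. num_video_frames,
# audio_sampling_rate, audio_chunk_length, load_audio_in_video):
#
#   messages = [{"value": ["Describe this clip.", Video("clip.mp4")]}]
#   media = extract_media(messages, config=cfg)
#   frames = media["video"][0]       # list of PIL images
#   print(messages[0]["value"])      # text now contains "<vila/video>"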