# omnivinci/media.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import io
import os
import random
import tarfile
import tempfile
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import kaldiio
import librosa
import numpy as np
import PIL
import PIL.Image
import requests
import soundfile as sf
import torch
import whisper
from decord import AudioReader, cpu
from transformers import PretrainedConfig
MEDIA_TOKENS = {
"image": "<image>",
"video": "<vila/video>",
"speech": "<speech>",
"sound": "<sound>",
}
class Media:
    """Base class for media objects."""
class File(Media):
    """File-based media object."""
    def __init__(self, path: str) -> None:
        self.path = path
class Image(File):
    """Image media object."""
class Video(File):
    """Video media object."""
class Speech(File):
"""Speech audio media object."""
    def __init__(self, path, extension: Optional[str] = None) -> None:
self.path = path
self.extension = extension
class Sound(File):
"""Sound/music audio media object."""
    def __init__(self, path, extension: Optional[str] = None) -> None:
self.path = path
self.extension = extension
def make_list(obj: Any) -> List:
"""Convert object to list if not already a list."""
return obj if isinstance(obj, list) else [obj]
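# Illustrative only (not part of the original module): make_list normalizes
# single values and lists alike, so callers can iterate uniformly.
#   >>> make_list("hello")
#   ['hello']
#   >>> make_list(["a", "b"])
#   ['a', 'b']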
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
"""Extract PIL Image from Image object or return PIL Image as-is."""
if isinstance(image, Image):
if image.path.startswith("http://") or image.path.startswith("https://"):
image = PIL.Image.open(requests.get(image.path, stream=True).raw)
else:
image = PIL.Image.open(image.path)
return image
def _load_video_bytesio(
    video_bytesio: BytesIO, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
) -> Tuple[List[PIL.Image.Image], Any, Dict[str, Any]]:
    """Load a video held in a BytesIO object by spilling it to a temporary file."""
    with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
        temp_video.write(video_bytesio.read())
        temp_video.flush()  # ensure all bytes are on disk before cv2 opens the file by name
        # decode inside the with-block so the temporary file still exists
        return _load_video(temp_video.name, num_frames=num_frames, load_aud=load_aud, config=config)
def get_overlap(inp1, inp2):
"""
Calculates the overlapping time frame between a video clip and an audio segment.
Args:
inp1 (list): [start_sec, end_sec]
inp2 (list): [start_sec, end_sec]
Returns:
tuple or None: (overlap_start, overlap_end) if overlap exists, else None.
"""
# Calculate the maximum start time and minimum end time
overlap_start = max(inp1[0], inp2[0])
overlap_end = min(inp1[1], inp2[1])
# Check if there is an actual overlap
if overlap_start < overlap_end:
return (overlap_start, overlap_end)
else:
return None
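# Illustrative only (not part of the original module): a 0-10 s clip and a
# 5-20 s audio segment overlap on 5-10 s; disjoint ranges return None.
#   >>> get_overlap([0.0, 10.0], [5.0, 20.0])
#   (5.0, 10.0)
#   >>> get_overlap([0.0, 4.0], [5.0, 20.0]) is None
#   True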
def _load_video(
    video_path: str, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
) -> Tuple[List[PIL.Image.Image], Any, Dict[str, Any]]:
    """Sample frames (and optionally audio) uniformly from a video file or a directory of frame images."""
    # Load video frames from a directory of per-frame images
    if os.path.isdir(video_path):
        frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
        indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
        frames = [PIL.Image.open(frame_paths[index]) for index in indices]
        # Match the (frames, audio_features, video_info) contract of the file-based path below
        return frames, None, {"expected_frame_count": len(indices), "video_path": video_path, "has_audio": False}
# Load video frames from a video file
vidcap = cv2.VideoCapture(video_path)
# Load audio if available and needed
audio_info = None
if load_aud:
try:
aud_feature, audio_info = _load_speech(video_path, config)
        except Exception as e:
            print(f"Failed to read audio track from '{video_path}': {e}")
            aud_feature = None
else:
aud_feature = None
    # The reported frame count can overstate the number of decodable frames,
    # so walk backwards from the end until a frame can actually be grabbed.
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:  # the loop finished without a break: no frame could be grabbed at all
        raise ValueError(f"Video '{video_path}' has no frames.")
# Extract frames uniformly
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    if fps <= 0:
        raise ValueError(f"Video '{video_path}' reports an invalid FPS of {fps}.")
    video_duration = frame_count / fps
    # When load_audio_in_video and interleaved_vis_aud_in_video are both set, frames
    # are grouped per video segment so visual and audio tokens can be interleaved.
    if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
        segment_duration = config.interleaved_video_segment_duration
        if segment_duration == -1:
            raise ValueError("interleaved_video_segment_duration is not set")
segment_vis_indices_list = []
segment_aud_indices_list = []
segment_counts = np.ceil(video_duration / segment_duration).astype(int)
audio_start_sec = audio_info['audio_start_sec']
audio_end_sec = audio_info['audio_end_sample_sec']
stft_frames_per_second = config.audio_sampling_rate // config.audio_hop_length
        _idx = 0
        aud_sample_start_idx = 0
        for i in range(segment_counts):
            end_frame = min((i + 1) * segment_duration * fps, frame_count)
            _indices = []
            while _idx < len(indices) and indices[_idx] < end_frame:
                _indices.append(indices[_idx])
                _idx += 1
segment_vis_indices_list.append(_indices)
clip_start_sec = i * segment_duration
clip_end_sec = min(clip_start_sec + segment_duration, video_duration)
# get the audio indices for the current clip
overlap = get_overlap([clip_start_sec, clip_end_sec], [audio_start_sec, audio_end_sec])
if overlap is not None:
aud_sample_end_idx = round((overlap[1] - audio_start_sec) * stft_frames_per_second)
segment_aud_indices_list.append([aud_sample_start_idx, aud_sample_end_idx])
aud_sample_start_idx = aud_sample_end_idx
else:
segment_aud_indices_list.append([])
frames = {}
frame_times = {}
for index in indices:
if index in frames:
continue
vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
success, frame = vidcap.read()
if not success:
print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames[index] = PIL.Image.fromarray(frame)
frame_times[index] = index / fps
output_frames = [frames[index] for index in indices if index in frames]
output_frame_times = [frame_times[index] for index in indices if index in frame_times]
video_info = {}
if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
new_segment_vis_indices_list = []
processed_frame_index = 0
        for segment_indices in segment_vis_indices_list:
new_segment_vis_indices_list.append([])
for index in segment_indices:
if index in frames:
new_segment_vis_indices_list[-1].append(processed_frame_index)
processed_frame_index += 1
segment_vis_indices_list = new_segment_vis_indices_list
video_info["segment_vis_indices_list"] = segment_vis_indices_list
video_info["segment_aud_indices_list"] = segment_aud_indices_list
video_info['expected_frame_count'] = len(indices)
video_info['video_path'] = video_path
if audio_info is not None:
audio_info['video_path'] = video_path
video_info['has_audio'] = aud_feature is not None
video_info['video_duration'] = video_duration
video_info['audio_info'] = audio_info
# calculate the time of each frame
video_info['video_frame_times'] = output_frame_times
return output_frames, aud_feature, video_info
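# Illustrative usage sketch (hypothetical path; `cfg` stands in for a suitable
# PretrainedConfig): _load_video returns a (frames, audio_features, video_info) triple.
#   frames, aud, info = _load_video("clip.mp4", num_frames=8, config=cfg, load_aud=False)
#   # len(frames) <= 8 (fewer if some frames fail to decode);
#   # info["video_frame_times"] holds each sampled frame's timestamp in seconds.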
def _extract_video(video: Video, config: PretrainedConfig):
    """Decode a Video into frames, plus audio features and info when load_audio_in_video is set."""
    num_frames = config.num_video_frames
    aud_fea = None
    if getattr(config, "fps", 0) != 0:
        print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
if isinstance(video.path, BytesIO):
frames, aud_fea, video_info = _load_video_bytesio(
video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
)
else:
frames, aud_fea, video_info = _load_video(
video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
)
if config.load_audio_in_video:
return frames, aud_fea, video_info
else:
return frames, video_info
def soundFile_read_audio(audio_file, offset=None, duration=None, dtype='float32'):
    """Read audio (path, file object, or raw bytes) with soundfile; returns (samples, sample_rate)."""
    if dtype not in ['int32', 'float32']:
        print("audio dtype must be int32 or float32. Defaulting to float32.")
        dtype = 'float32'
if isinstance(audio_file, bytes):
audio_file = io.BytesIO(audio_file)
with sf.SoundFile(audio_file, 'r') as f:
sample_rate = f.samplerate
if offset is not None and offset > 0:
f.seek(int(offset * sample_rate))
if duration is not None and duration > 0:
samples = f.read(int(duration * sample_rate), dtype=dtype)
else:
samples = f.read(dtype=dtype)
return samples, sample_rate
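# Illustrative usage (hypothetical file): read 2 s of audio starting at 0.5 s.
# soundfile keeps the file's native sample rate, so no resampling happens here.
#   samples, sr = soundFile_read_audio("clip.wav", offset=0.5, duration=2.0)
#   # for a mono file that is long enough, samples.shape[0] == int(2.0 * sr)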
def load_audio_from_tar(tar_file, audio_file):
    """Extract `audio_file` from `tar_file` and decode it with librosa (resampled to librosa's 22050 Hz default)."""
    with tarfile.open(tar_file, 'r') as tar:
        audio_member = tar.getmember(audio_file)
        audio_file = tar.extractfile(audio_member)
        # decode inside the with-block so the tar member stream is still open
        return librosa.load(audio_file)
def _load_audio_file(audio_path: str, config: PretrainedConfig):
    """Load an audio file from disk or from inside a .tar archive; returns (samples, sample_rate)."""
    if audio_path is None:
        return None, None
dirname = os.path.dirname(audio_path)
filename = os.path.basename(audio_path)
if dirname.endswith(".tar"):
speech, sample_rate = load_audio_from_tar(dirname, filename)
else:
sample_rate = config.audio_sampling_rate
speech = whisper.load_audio(audio_path, sr=sample_rate)
return speech, sample_rate
def _load_audio(audio: Union[str, dict], config: PretrainedConfig):
if isinstance(audio, str):
return _load_audio_file(audio, config)
elif isinstance(audio, dict):
audio_sample = audio['sample']
if isinstance(audio_sample, (bytes, io.BytesIO)):
offset = audio.get('offset', None)
duration = audio.get('duration', None)
dtype = audio.get('dtype', 'float32')
return soundFile_read_audio(
audio_sample, offset=offset, duration=duration, dtype=dtype
)
elif isinstance(audio_sample, np.ndarray):
return audio_sample, audio.get('sample_rate')
else:
raise ValueError(f"Expect the loaded audio to be a processed numpy array or raw bytes. Got {type(audio_sample)}")
else:
raise ValueError(f"Expect input to be a path string or dict. Got {type(audio)}")
def _whisper_process(audio, sample_rate, audio_chunk_length, max_chunks_per_file):
    """Split audio into fixed-length chunks and compute a 128-bin log-mel spectrogram for each."""
    outputs = []
    num_audio_chunks = 0
chunk_length = audio_chunk_length * sample_rate
for i in range(0, len(audio), chunk_length):
chunk = audio[i : i + chunk_length]
chunk = whisper.pad_or_trim(chunk)
if chunk.dtype != np.float32:
chunk = chunk.astype(np.float32)
mel = whisper.log_mel_spectrogram(chunk, n_mels=128)
        num_audio_chunks += 1
outputs.append(mel)
if num_audio_chunks == max_chunks_per_file:
break
frames = torch.stack(outputs, dim=0)
return frames.numpy().tolist()
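# Illustrative only: with whisper's defaults, pad_or_trim fixes every chunk to
# 30 s (480000 samples at 16 kHz), and log_mel_spectrogram with n_mels=128
# yields a (128, 3000) mel per chunk, so the stacked result has shape
# (num_chunks, 128, 3000) before being converted with .tolist().
#   mels = _whisper_process(audio, 16000, audio_chunk_length=30, max_chunks_per_file=4)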
def _load_speech(speech, config: PretrainedConfig):
    """Load a waveform for a Speech/Sound object, path string, dict, or video file; returns (waveform, audio_info)."""
    if isinstance(speech, str):
        speech_path = speech
    else:
        speech_path = speech.path
    if speech_path is None:
        return None, None
    if config.audio_chunk_length and not (isinstance(config.audio_chunk_length, str) and "max" in config.audio_chunk_length):
        try:
            config.audio_chunk_length = int(config.audio_chunk_length)
        except Exception as e:
            print(f"Error setting audio_chunk_length: {e}")
            raise
    # When audio_chunk_length is a "max..." string this product is never used
    # numerically; get_audio() takes its load_max branch instead.
    audio_n_samples_limit = config.audio_chunk_length * config.audio_sampling_rate
    def load_wav(speech_path):
        speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
        ori_audio_duration = speech.shape[0] / sr
        return speech, ori_audio_duration
    def get_audio(speech, audio_n_samples):
        """Select up to audio_n_samples samples; returns (samples, n_samples, start_idx, end_idx)."""
        if isinstance(speech, AudioReader):
            ori_n_samples = speech.shape[1]
        else:
            ori_n_samples = speech.shape[0]
        # pick the audio sample window (random crop when config.random_audio_sample is present)
        audio_start_sample_id = 0
        audio_end_sample_id = ori_n_samples
        load_max_audio = isinstance(config.audio_chunk_length, str) and "max" in config.audio_chunk_length
        if hasattr(config, 'random_audio_sample') and not load_max_audio:
            if ori_n_samples > audio_n_samples:
                audio_start_sample_id = random.randint(0, ori_n_samples - audio_n_samples)
                audio_end_sample_id = audio_start_sample_id + audio_n_samples
        else:
            if load_max_audio:
                if "_" in config.audio_chunk_length:
                    # "max_<seconds>" caps the loaded audio at <seconds>
                    max_audio_chunk_length = int(config.audio_chunk_length.split("_")[1])
                    max_audio_n_samples = max_audio_chunk_length * config.audio_sampling_rate
                    audio_n_samples = min(ori_n_samples, max_audio_n_samples)
                    audio_end_sample_id = audio_n_samples
                else:
                    # plain "max" loads the full waveform
                    audio_n_samples = ori_n_samples
                    audio_end_sample_id = audio_n_samples
            else:
                audio_end_sample_id = min(audio_n_samples, ori_n_samples)
        if isinstance(speech, AudioReader):
            speech = speech[audio_start_sample_id:audio_end_sample_id].asnumpy()[0]
        else:
            speech = speech[audio_start_sample_id:audio_end_sample_id]
        return speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id
if isinstance(speech_path, dict):
if "offset" in speech_path:
speech, ori_sample_rate = _load_audio(speech_path, config)
else:
speech = speech_path["sample"]
ori_sample_rate = speech_path["sample_rate"]
        # resample from the source rate to the model's expected sampling rate
        speech = librosa.resample(speech, orig_sr=ori_sample_rate, target_sr=config.audio_sampling_rate)
# variable audio sequence lengths
ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
    elif isinstance(speech_path, BytesIO):
        if speech.extension == ".wav":
            speech, ori_audio_duration = load_wav(speech_path)
            speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
        else:
            raise ValueError(f"Unsupported audio extension: {speech.extension}")
elif ".mat" in speech_path or ".ark" in speech_path:
rate, speech = kaldiio.load_mat(speech_path)
speech = librosa.resample(speech, orig_sr=rate, target_sr=config.audio_sampling_rate)
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
elif ".mp4" in speech_path:
# Load audio from video file
ar = AudioReader(speech_path, ctx=cpu(0), sample_rate=config.audio_sampling_rate, mono=True)
cur_max_length = ar.shape[1]
ori_audio_duration = cur_max_length / config.audio_sampling_rate
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(ar, audio_n_samples_limit)
else:
assert os.path.exists(speech_path), f"File {speech_path} does not exist"
speech, ori_audio_duration = load_wav(speech_path)
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
    # convert to float32 for whisper
    speech = speech.astype(np.float32)
    # pad the clip up to the next 30 s multiple; final padding to the batch-wide
    # max length happens later, when samples are collated
    audio_n_samples = int(np.ceil(speech.shape[0] / (config.audio_sampling_rate * 30)) * (config.audio_sampling_rate * 30))
    speech = whisper.pad_or_trim(speech, length=audio_n_samples)
    new_audio_chunk_length = int(audio_n_samples // config.audio_sampling_rate)
audio_start_sec = audio_start_sample_id / config.audio_sampling_rate
audio_end_sample_sec = audio_end_sample_id / config.audio_sampling_rate
audio_info = {}
audio_info['new_audio_chunk_length'] = new_audio_chunk_length
audio_info['new_audio_n_samples'] = audio_n_samples
audio_info['ori_audio_duration'] = ori_audio_duration
audio_info['audio_start_sec'] = audio_start_sec
audio_info['audio_end_sample_sec'] = audio_end_sample_sec
return speech, audio_info
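# Illustrative only (assuming config.audio_chunk_length == "max" and a 16 kHz
# sampling rate): a 70 s file is loaded in full and padded up to the next 30 s
# multiple, so the waveform comes back with 90 * 16000 samples while
# audio_info["ori_audio_duration"] still reports ~70.0.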
def _extract_speech(speech: Speech, config: PretrainedConfig):
    """Load a Speech object's waveform and metadata; returns (waveform, audio_info)."""
    frames, audio_info = _load_speech(speech, config)
    return frames, audio_info
# Sound objects are loaded the same way as Speech objects.
_extract_sound = _extract_speech
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Decode media parts in conversation messages and replace each with its media token in the text."""
    media = defaultdict(list)
    if not hasattr(config, "load_audio_in_video"):
        print("Warning: load_audio_in_video not in config; defaulting to False.")
        config.load_audio_in_video = False
for message in messages:
text = ""
for part in make_list(message["value"]):
if isinstance(part, str):
for token in MEDIA_TOKENS.values():
if token in part:
print(f"Media token '{token}' found in text: '{part}'. Removed.")
part = part.replace(token, "").strip()
text += part
elif isinstance(part, (Image, PIL.Image.Image)):
if draft:
media["image"].append(part)
else:
media["image"].append(_extract_image(part))
text += MEDIA_TOKENS["image"]
elif isinstance(part, Video):
if draft:
media["video"].append(part)
else:
if config.load_audio_in_video:
output, aud_fea, video_info = _extract_video(part, config)
media["video"].append(output)
media["video_info"].append(video_info)
if aud_fea is not None:
media["sound"].append(aud_fea)
media["audio_info"].append(video_info['audio_info'])
text += MEDIA_TOKENS["sound"]
else:
output, video_info = _extract_video(part, config)
media["video"].append(output)
media["video_info"].append(video_info)
text += MEDIA_TOKENS["video"]
elif isinstance(part, Speech):
if draft:
if config.unified_audio_encoder:
media["sound"].append(part)
text += MEDIA_TOKENS["sound"]
else:
media["speech"].append(part)
text += MEDIA_TOKENS["speech"]
else:
output, audio_info = _extract_speech(part, config)
if output is not None:
if config.unified_audio_encoder:
media["sound"].append(output)
text += MEDIA_TOKENS["sound"]
else:
media["speech"].append(output)
text += MEDIA_TOKENS["speech"]
media["audio_info"].append(audio_info)
elif isinstance(part, Sound):
if draft:
media["sound"].append(part)
text += MEDIA_TOKENS["sound"]
else:
output, audio_info = _extract_sound(part, config)
if output is not None:
media["sound"].append(output)
media["audio_info"].append(audio_info)
text += MEDIA_TOKENS["sound"]
else:
print(f"part: {part}")
raise ValueError(f"Unsupported prompt part type: {type(part)}")
message["value"] = text
return media
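
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The config values
    # below are hypothetical stand-ins; a real config for this model defines
    # num_video_frames, fps, load_audio_in_video, and the audio_* settings.
    demo_config = PretrainedConfig(
        num_video_frames=8,
        fps=0,
        load_audio_in_video=False,
        unified_audio_encoder=True,
        audio_sampling_rate=16000,
        audio_chunk_length=30,
        audio_hop_length=160,
    )
    demo_messages = [{"from": "human", "value": [Image("example.jpg"), "Describe the image."]}]
    # draft=True skips decoding and only tags the media placeholders in the text.
    demo_media = extract_media(demo_messages, demo_config, draft=True)
    print(demo_messages[0]["value"])   # "<image>Describe the image."
    print(list(demo_media.keys()))     # ["image"]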